From 70b93234ec21cd5a4eaca3d521da3d42c801a15b Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Wed, 1 Nov 2023 10:10:20 +0800 Subject: [PATCH] Implement prototype for torch based fermionic library. --- .github/workflows/CI.yml | 57 + .gitignore | 5 + LICENSE.md | 675 ++++++++++++ README.md | 3 + pyproject.toml | 32 + tat/__init__.py | 6 + tat/_qr.py | 224 ++++ tat/_svd.py | 283 +++++ tat/_utility.py | 37 + tat/compat.py | 439 ++++++++ tat/edge.py | 297 ++++++ tat/tensor.py | 1947 +++++++++++++++++++++++++++++++++++ tests/test_compat.py | 124 +++ tests/test_create_tensor.py | 103 ++ tests/test_edge.py | 39 + tests/test_qr.py | 59 ++ tests/test_svd.py | 40 + 17 files changed, 4370 insertions(+) create mode 100644 .github/workflows/CI.yml create mode 100644 .gitignore create mode 100644 LICENSE.md create mode 100644 README.md create mode 100644 pyproject.toml create mode 100644 tat/__init__.py create mode 100644 tat/_qr.py create mode 100644 tat/_svd.py create mode 100644 tat/_utility.py create mode 100644 tat/compat.py create mode 100644 tat/edge.py create mode 100644 tat/tensor.py create mode 100644 tests/test_compat.py create mode 100644 tests/test_create_tensor.py create mode 100644 tests/test_edge.py create mode 100644 tests/test_qr.py create mode 100644 tests/test_svd.py diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml new file mode 100644 index 000000000..1216edc09 --- /dev/null +++ b/.github/workflows/CI.yml @@ -0,0 +1,57 @@ +name: CI + +on: [push, pull_request] + +jobs: + CI: + + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - python-version: "3.9" + pytorch-version: "1.12" + - python-version: "3.9" + pytorch-version: "1.13" + - python-version: "3.9" + pytorch-version: "2.0" + - python-version: "3.9" + pytorch-version: "2.1" + + - python-version: "3.10" + pytorch-version: "1.12" + - python-version: "3.10" + pytorch-version: "1.13" + - python-version: "3.10" + pytorch-version: "2.0" + - python-version: "3.10" + pytorch-version: "2.1" + + - python-version: "3.11" + pytorch-version: "1.13" + - python-version: "3.11" + pytorch-version: "2.0" + - python-version: "3.11" + pytorch-version: "2.1" + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install requirements + run: | + pip install pylint==2.17 mypy==1.6 pytest==7.4 pytest-cov==4.1 + pip install torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu + pip install multimethod + - name: Run pylint + run: pylint tat tests + working-directory: ${{ github.workspace }} + - name: Run mypy + run: mypy tat tests + working-directory: ${{ github.workspace }} + - name: Run pytest + run: pytest + working-directory: ${{ github.workspace }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..f1f50703b --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.coverage +.mypy_cache +__pycache__ +*.pyc +env \ No newline at end of file diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 000000000..496acdb2a --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,675 @@ +# GNU GENERAL PUBLIC LICENSE + +Version 3, 29 June 2007 + +Copyright (C) 2007 Free Software Foundation, Inc. + + +Everyone is permitted to copy and distribute verbatim copies of this +license document, but changing it is not allowed. + +## Preamble + +The GNU General Public License is a free, copyleft license for +software and other kinds of works. + +The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom +to share and change all versions of a program--to make sure it remains +free software for all its users. We, the Free Software Foundation, use +the GNU General Public License for most of our software; it applies +also to any other work released this way by its authors. You can apply +it to your programs, too. + +When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + +To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you +have certain responsibilities if you distribute copies of the +software, or if you modify it: responsibilities to respect the freedom +of others. + +For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + +Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + +For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + +Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the +manufacturer can do so. This is fundamentally incompatible with the +aim of protecting users' freedom to change the software. The +systematic pattern of such abuse occurs in the area of products for +individuals to use, which is precisely where it is most unacceptable. +Therefore, we have designed this version of the GPL to prohibit the +practice for those products. If such problems arise substantially in +other domains, we stand ready to extend this provision to those +domains in future versions of the GPL, as needed to protect the +freedom of users. + +Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish +to avoid the special danger that patents applied to a free program +could make it effectively proprietary. To prevent this, the GPL +assures that patents cannot be used to render the program non-free. + +The precise terms and conditions for copying, distribution and +modification follow. + +## TERMS AND CONDITIONS + +### 0. Definitions. + +"This License" refers to version 3 of the GNU General Public License. + +"Copyright" also means copyright-like laws that apply to other kinds +of works, such as semiconductor masks. + +"The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + +To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of +an exact copy. The resulting work is called a "modified version" of +the earlier work or a work "based on" the earlier work. + +A "covered work" means either the unmodified Program or a work based +on the Program. + +To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + +To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user +through a computer network, with no transfer of a copy, is not +conveying. + +An interactive user interface displays "Appropriate Legal Notices" to +the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + +### 1. Source Code. + +The "source code" for a work means the preferred form of the work for +making modifications to it. "Object code" means any non-source form of +a work. + +A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + +The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + +The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + +The Corresponding Source need not include anything that users can +regenerate automatically from other parts of the Corresponding Source. + +The Corresponding Source for a work in source code form is that same +work. + +### 2. Basic Permissions. + +All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + +You may make, run and propagate covered works that you do not convey, +without conditions so long as your license otherwise remains in force. +You may convey covered works to others for the sole purpose of having +them make modifications exclusively for you, or provide you with +facilities for running those works, provided that you comply with the +terms of this License in conveying all material for which you do not +control copyright. Those thus making or running the covered works for +you must do so exclusively on your behalf, under your direction and +control, on terms that prohibit them from making any copies of your +copyrighted material outside their relationship with you. + +Conveying under any other circumstances is permitted solely under the +conditions stated below. Sublicensing is not allowed; section 10 makes +it unnecessary. + +### 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + +No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + +When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such +circumvention is effected by exercising rights under this License with +respect to the covered work, and you disclaim any intention to limit +operation or modification of the work as a means of enforcing, against +the work's users, your or third parties' legal rights to forbid +circumvention of technological measures. + +### 4. Conveying Verbatim Copies. + +You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + +You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + +### 5. Conveying Modified Source Versions. + +You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these +conditions: + +- a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. +- b) The work must carry prominent notices stating that it is + released under this License and any conditions added under + section 7. This requirement modifies the requirement in section 4 + to "keep intact all notices". +- c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. +- d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + +A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + +### 6. Conveying Non-Source Forms. + +You may convey a covered work in object code form under the terms of +sections 4 and 5, provided that you also convey the machine-readable +Corresponding Source under the terms of this License, in one of these +ways: + +- a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. +- b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the Corresponding + Source from a network server at no charge. +- c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. +- d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. +- e) Convey the object code using peer-to-peer transmission, + provided you inform other peers where the object code and + Corresponding Source of the work are being offered to the general + public at no charge under subsection 6d. + +A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + +A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, +family, or household purposes, or (2) anything designed or sold for +incorporation into a dwelling. In determining whether a product is a +consumer product, doubtful cases shall be resolved in favor of +coverage. For a particular product received by a particular user, +"normally used" refers to a typical or common use of that class of +product, regardless of the status of the particular user or of the way +in which the particular user actually uses, or expects or is expected +to use, the product. A product is a consumer product regardless of +whether the product has substantial commercial, industrial or +non-consumer uses, unless such uses represent the only significant +mode of use of the product. + +"Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to +install and execute modified versions of a covered work in that User +Product from a modified version of its Corresponding Source. The +information must suffice to ensure that the continued functioning of +the modified object code is in no case prevented or interfered with +solely because modification has been made. + +If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + +The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or +updates for a work that has been modified or installed by the +recipient, or for the User Product in which it has been modified or +installed. Access to a network may be denied when the modification +itself materially and adversely affects the operation of the network +or violates the rules and protocols for communication across the +network. + +Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + +### 7. Additional Terms. + +"Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + +When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + +Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders +of that material) supplement the terms of this License with terms: + +- a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or +- b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or +- c) Prohibiting misrepresentation of the origin of that material, + or requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or +- d) Limiting the use for publicity purposes of names of licensors + or authors of the material; or +- e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or +- f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions + of it) with contractual assumptions of liability to the recipient, + for any liability that these contractual assumptions directly + impose on those licensors and authors. + +All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + +If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + +Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; the +above requirements apply either way. + +### 8. Termination. + +You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + +However, if you cease all violation of this License, then your license +from a particular copyright holder is reinstated (a) provisionally, +unless and until the copyright holder explicitly and finally +terminates your license, and (b) permanently, if the copyright holder +fails to notify you of the violation by some reasonable means prior to +60 days after the cessation. + +Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + +Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + +### 9. Acceptance Not Required for Having Copies. + +You are not required to accept this License in order to receive or run +a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + +### 10. Automatic Licensing of Downstream Recipients. + +Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + +An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + +You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + +### 11. Patents. + +A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + +A contributor's "essential patent claims" are all patent claims owned +or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + +Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + +In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + +If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + +If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + +A patent license is "discriminatory" if it does not include within the +scope of its coverage, prohibits the exercise of, or is conditioned on +the non-exercise of one or more of the rights that are specifically +granted under this License. You may not convey a covered work if you +are a party to an arrangement with a third party that is in the +business of distributing software, under which you make payment to the +third party based on the extent of your activity of conveying the +work, and under which the third party grants, to any of the parties +who would receive the covered work from you, a discriminatory patent +license (a) in connection with copies of the covered work conveyed by +you (or copies made from those copies), or (b) primarily for and in +connection with specific products or compilations that contain the +covered work, unless you entered into that arrangement, or that patent +license was granted, prior to 28 March 2007. + +Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + +### 12. No Surrender of Others' Freedom. + +If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under +this License and any other pertinent obligations, then as a +consequence you may not convey it at all. For example, if you agree to +terms that obligate you to collect a royalty for further conveying +from those to whom you convey the Program, the only way you could +satisfy both those terms and this License would be to refrain entirely +from conveying the Program. + +### 13. Use with the GNU Affero General Public License. + +Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + +### 14. Revised Versions of this License. + +The Free Software Foundation may publish revised and/or new versions +of the GNU General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in +detail to address new problems or concerns. + +Each version is given a distinguishing version number. If the Program +specifies that a certain numbered version of the GNU General Public +License "or any later version" applies to it, you have the option of +following the terms and conditions either of that numbered version or +of any later version published by the Free Software Foundation. If the +Program does not specify a version number of the GNU General Public +License, you may choose any version ever published by the Free +Software Foundation. + +If the Program specifies that a proxy can decide which future versions +of the GNU General Public License can be used, that proxy's public +statement of acceptance of a version permanently authorizes you to +choose that version for the Program. + +Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + +### 15. Disclaimer of Warranty. + +THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT +WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND +PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE +DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR +CORRECTION. + +### 16. Limitation of Liability. + +IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR +CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, +INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES +ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT +NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR +LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM +TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER +PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +### 17. Interpretation of Sections 15 and 16. + +If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + +END OF TERMS AND CONDITIONS + +## How to Apply These Terms to Your New Programs + +If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these +terms. + +To do so, attach the following notices to the program. It is safest to +attach them to the start of each source file to most effectively state +the exclusion of warranty; and each file should have at least the +"copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper +mail. + +If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands \`show w' and \`show c' should show the +appropriate parts of the General Public License. Of course, your +program's commands might be different; for a GUI interface, you would +use an "about box". + +You should also get your employer (if you work as a programmer) or +school, if any, to sign a "copyright disclaimer" for the program, if +necessary. For more information on this, and how to apply and follow +the GNU GPL, see . + +The GNU General Public License does not permit incorporating your +program into proprietary programs. If your program is a subroutine +library, you may consider it more useful to permit linking proprietary +applications with the library. If this is what you want to do, use the +GNU Lesser General Public License instead of this License. But first, +please read . diff --git a/README.md b/README.md new file mode 100644 index 000000000..3370c23ae --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# TAT + +A Fermionic tensor library based on pytorch. diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..d69738304 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,32 @@ +[project] +name = "tat" +version = "0.4.0" +authors = [ + {email = "zh970205@mail.ustc.edu.cn", name = "Hao Zhang"} +] +description = "A Fermionic tensor library based on pytorch." +readme = "README.md" +requires-python = ">=3.9" +license = {text = "GPL-3.0-or-later"} +dependencies = [ + "multimethod>=1.9", + "torch>=1.12", +] + +[tool.pylint] +max-line-length = 120 +generated-members = "torch.*" +init-hook="import sys; sys.path.append(\".\")" + +[tool.yapf] +based_on_style = "google" +column_limit = 120 + +[tool.mypy] +check_untyped_defs = true +disallow_untyped_defs = true + +[tool.pytest.ini_options] +pythonpath = "." +testpaths = ["tests",] +addopts = "--cov=tat" diff --git a/tat/__init__.py b/tat/__init__.py new file mode 100644 index 000000000..9f8cc4bcb --- /dev/null +++ b/tat/__init__.py @@ -0,0 +1,6 @@ +""" +The tat is a Fermionic tensor library based on pytorch. +""" + +from .edge import Edge +from .tensor import Tensor diff --git a/tat/_qr.py b/tat/_qr.py new file mode 100644 index 000000000..5073a49c1 --- /dev/null +++ b/tat/_qr.py @@ -0,0 +1,224 @@ +""" +This module implements QR decomposition based on Givens rotation and Householder reflection. +""" + +import typing +import torch + +# pylint: disable=invalid-name + + +def _syminvadj(X: torch.Tensor) -> torch.Tensor: + ret = X + X.H + ret.diagonal().real[:] *= 1 / 2 + return ret + + +def _triliminvadjskew(X: torch.Tensor) -> torch.Tensor: + ret = torch.tril(X - X.H) + if torch.is_complex(X): + ret.diagonal().imag[:] *= 1 / 2 + return ret + + +def _qr_backward( + Q: torch.Tensor, + R: torch.Tensor, + Q_grad: typing.Optional[torch.Tensor], + R_grad: typing.Optional[torch.Tensor], +) -> typing.Optional[torch.Tensor]: + # see https://arxiv.org/pdf/2009.10071.pdf section 4.3 and 4.5 + # see pytorch torch/csrc/autograd/FunctionsManual.cpp:linalg_qr_backward + m = Q.size(0) + n = R.size(1) + + if Q_grad is not None: + if R_grad is not None: + MH = R_grad @ R.H - Q.H @ Q_grad + else: + MH = -Q.H @ Q_grad + else: + if R_grad is not None: + MH = R_grad @ R.H + else: + return None + + # pylint: disable=no-else-return + if m >= n: + # Deep and square matrix + b = Q @ _syminvadj(torch.triu(MH)) + if Q_grad is not None: + b = b + Q_grad + return torch.linalg.solve_triangular(R.H, b, upper=False, left=False) + else: + # Wide matrix + b = Q @ (_triliminvadjskew(-MH)) + result = torch.linalg.solve_triangular(R[:, :m].H, b, upper=False, left=False) + result = torch.cat((result, torch.zeros([m, n - m], dtype=result.dtype, device=result.device)), dim=1) + if R_grad is not None: + result = result + Q @ R_grad + return result + + +class CommonQR(torch.autograd.Function): + """ + Implement the autograd function for QR. + """ + + # pylint: disable=abstract-method + + @staticmethod + def backward( # type: ignore[override] + ctx: typing.Any, + Q_grad: typing.Optional[torch.Tensor], + R_grad: typing.Optional[torch.Tensor], + ) -> typing.Optional[torch.Tensor]: + # pylint: disable=arguments-differ + Q, R = ctx.saved_tensors + return _qr_backward(Q, R, Q_grad, R_grad) + + +def _normalize_diagonal(a: torch.Tensor) -> torch.Tensor: + r = torch.sqrt(a.conj() * a) + return torch.where( + r == torch.zeros([], dtype=a.dtype, device=a.device), + torch.ones([], dtype=a.dtype, device=a.device), + a / r, + ) + + +def _givens_parameter(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + r = torch.sqrt(a.conj() * a + b.conj() * b) + return torch.where( + b == torch.zeros([], dtype=a.dtype, device=a.device), + torch.ones([], dtype=a.dtype, device=a.device), + a / r, + ), torch.where( + b == torch.zeros([], dtype=a.dtype, device=a.device), + torch.zeros([], dtype=a.dtype, device=a.device), + b / r, + ) + + +def _givens_qr(A: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + m, n = A.shape + k = min(m, n) + Q = torch.eye(m, dtype=A.dtype, device=A.device) + R = A.clone(memory_format=torch.contiguous_format) + + # Parallel strategy + # Every row rotated to the nearest row above + for g in range(m - 1, 0, -1): + # rotate R[g, 0], R[g+2, 1], R[g+4, 2], ... + for i, col in zip(range(g, m, 2), range(n)): + j = i - 1 + # Rotate inside column col + # Rotate from row i to row j + c, s = _givens_parameter(R[j, col], R[i, col]) + Q[i], Q[j] = -s * Q[j] + c * Q[i], c.conj() * Q[j] + s.conj() * Q[i] + R[i], R[j] = -s * R[j] + c * R[i], c.conj() * R[j] + s.conj() * R[i] + for g in range(1, k): + # rotate R[g+1, g], R[g+1+2, g+1], R[g+1+4, g+2], ... + for i, col in zip(range(g + 1, m, 2), range(g, n)): + j = i - 1 + # Rotate inside column col + # Rotate from row i to row j + c, s = _givens_parameter(R[j, col], R[i, col]) + Q[i], Q[j] = -s * Q[j] + c * Q[i], c.conj() * Q[j] + s.conj() * Q[i] + R[i], R[j] = -s * R[j] + c * R[i], c.conj() * R[j] + s.conj() * R[i] + + # for j in range(n): + # for i in range(j + 1, m): + # col = j + # # Rotate inside column col + # # Rotate from row i to row j + # c, s = _givens_parameter(R[j, col], R[i, col]) + # Q[i], Q[j] = -s * Q[j] + c * Q[i], c.conj() * Q[j] + s.conj() * Q[i] + # R[i], R[j] = -s * R[j] + c * R[i], c.conj() * R[j] + s.conj() * R[i] + + # Make diagonal positive + c = _normalize_diagonal(R.diagonal()).conj() + Q[:k] *= torch.unsqueeze(c, 1) + R[:k] *= torch.unsqueeze(c, 1) + + Q, R = Q[:k].H, R[:k] + return Q, R + + +class GivensQR(CommonQR): + """ + Compute the reduced QR decomposition using Givens rotation. + """ + + # pylint: disable=abstract-method + + @staticmethod + def forward( # type: ignore[override] + ctx: torch.autograd.function.FunctionCtx, + A: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + # pylint: disable=arguments-differ + Q, R = _givens_qr(A) + ctx.save_for_backward(Q, R) + return Q, R + + +def _normalize_delta(a: torch.Tensor) -> torch.Tensor: + norm = a.norm() + return torch.where( + norm == torch.zeros([], dtype=a.dtype, device=a.device), + torch.zeros([], dtype=a.dtype, device=a.device), + a / norm, + ) + + +def _reflect_target(x: torch.Tensor) -> torch.Tensor: + return torch.norm(x) * _normalize_diagonal(x[0]) + + +def _householder_qr(A: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + m, n = A.shape + k = min(m, n) + Q = torch.eye(m, dtype=A.dtype, device=A.device) + R = A.clone(memory_format=torch.contiguous_format) + + for i in range(k): + x = R[i:, i] + v = torch.zeros_like(x) + # For complex matrix, it require = , i.e. v[0] and x[0] have opposite argument. + v[0] = _reflect_target(x) + # Reflect x to v + delta = _normalize_delta(v - x) + # H = 1 - 2 |Delta> tuple[torch.Tensor, torch.Tensor]: + # pylint: disable=arguments-differ + Q, R = _householder_qr(A) + ctx.save_for_backward(Q, R) + return Q, R + + +givens_qr = GivensQR.apply +householder_qr = HouseholderQR.apply diff --git a/tat/_svd.py b/tat/_svd.py new file mode 100644 index 000000000..ac6fa0bff --- /dev/null +++ b/tat/_svd.py @@ -0,0 +1,283 @@ +""" +This module implements SVD decomposition without Householder reflection. +""" + +import typing +import torch +from ._qr import _normalize_diagonal, _givens_parameter + +# pylint: disable=invalid-name + + +def _svd(A: torch.Tensor, error: float) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # pylint: disable=too-many-locals + # pylint: disable=too-many-branches + # pylint: disable=too-many-statements + # pylint: disable=too-many-nested-blocks + + # see https://web.stanford.edu/class/cme335/lecture6.pdf + m, n = A.shape + trans = False + if m < n: + trans = True + A = A.transpose(0, 1) + m, n = n, m + U = torch.eye(m, dtype=A.dtype, device=A.device) + V = torch.eye(n, dtype=A.dtype, device=A.device) + + # Make bidiagonal matrix + B = A.clone(memory_format=torch.contiguous_format) + for i in range(n): + # (i:, i) + for j in range(m - 1, i, -1): + col = i + # Rotate inside col i + # Rotate from row j to j-1 + c, s = _givens_parameter(B[j - 1, col], B[j, col]) + U[j], U[j - 1] = -s * U[j - 1] + c * U[j], c.conj() * U[j - 1] + s.conj() * U[j] + B[j], B[j - 1] = -s * B[j - 1] + c * B[j], c.conj() * B[j - 1] + s.conj() * B[j] + # x = B[i:, i] + # v = torch.zeros_like(x) + # v[0] = _reflect_target(x) + # delta = _normalize_delta(v - x) + # B[i:, :] -= 2 * torch.outer(delta, delta.conj() @ B[i:, :]) + # U[i:, :] -= 2 * torch.outer(delta, delta.conj() @ U[i:, :]) + + # (i, i+1:)/H + if i == n - 1: + break + for j in range(n - 1, i + 1, -1): + row = i + # Rotate inside row i + # Rotate from col j to j-1 + c, s = _givens_parameter(B[row, j - 1], B[row, j]) + V[j], V[j - 1] = -s * V[j - 1] + c * V[j], c.conj() * V[j - 1] + s.conj() * V[j] + B[:, j], B[:, j - 1] = -s * B[:, j - 1] + c * B[:, j], c.conj() * B[:, j - 1] + s.conj() * B[:, j] + # x = B[i, i + 1:] + # v = torch.zeros_like(x) + # v[0] = _reflect_target(x) + # delta = _normalize_delta(v - x) + # B[:, i + 1:] -= 2 * torch.outer(B[:, i + 1:] @ delta.conj(), delta) + # V[i + 1:, :] -= 2 * torch.outer(delta, delta.conj() @ V[i + 1:, :]) + B = B[:n] + U = U[:n] + # print(B) + # error_decomp = torch.max(torch.abs(U.H @ B @ V.H.T - A)).item() + # assert error_decomp < 1e-4 + + # QR iteration with implicit Q + S = torch.diagonal(B).clone(memory_format=torch.contiguous_format) + F = torch.diagonal(B, offset=1).clone(memory_format=torch.contiguous_format) + F.resize_(S.size(0)) + F[-1] = 0 + X = F[-1] + stack: list[tuple[int, int]] = [(0, n - 1)] + while stack: + # B.zero_() + # B.diagonal()[:] = S + # B.diagonal(offset = 1)[:] = F[:-1] + # error_decomp = torch.max(torch.abs(U.H @ B @ V.H.T - A)).item() + # assert error_decomp < 1e-4 + + low = stack[-1][0] + high = stack[-1][1] + + if low == high: + stack.pop() + continue + + max_diagonal = torch.abs(S[low]) + for b in range(low, high + 1): + Sb = torch.abs(S[b]) + if Sb < max_diagonal: + max_diagonal = Sb + # Check if S[b] is zero + if Sb < error: + # pylint: disable=no-else-continue + if b == low: + X = F[b].clone() + F[b] = 0 + for i in range(b + 1, high + 1): + c, s = _givens_parameter(S[i], X) + U[b], U[i] = -s * U[i] + c * U[b], c.conj() * U[i] + s.conj() * U[b] + + S[i] = c.conj() * S[i] + s.conj() * X + if i != high: + X, F[i] = -s * F[i] + c * X, c.conj() * F[i] + s.conj() * X + stack.pop() + stack.append((b + 1, high)) + stack.append((low, b)) + continue + else: + X = F[b - 1].clone() + F[b - 1] = 0 + for i in range(b - 1, low - 1, -1): + c, s = _givens_parameter(S[i], X) + V[b], V[i] = -s * V[i] + c * V[b], c.conj() * V[i] + s.conj() * V[b] + + S[i] = c.conj() * S[i] + s.conj() * X + if i != low: + X, F[i - 1] = -s * F[i - 1] + c * X, c.conj() * F[i - 1] + s.conj() * X + stack.pop() + stack.append((b, high)) + stack.append((low, b - 1)) + continue + + b = int(torch.argmin(torch.abs(F[low:high]))) + low + if torch.abs(F[b]) < max_diagonal * error: + F[b] = 0 + stack.pop() + stack.append((b + 1, high)) + stack.append((low, b)) + continue + + tdn = (S[b + 1].conj() * S[b + 1] + F[b].conj() * F[b]).real + tdn_1 = (S[b].conj() * S[b] + F[b - 1].conj() * F[b - 1]).real + tsn_1 = F[b].conj() * S[b] + d = (tdn_1 - tdn) / 2 + mu = tdn + d - torch.sign(d) * torch.sqrt(d**2 + tsn_1.conj() * tsn_1) + for i in range(low, high): + if i == low: + c, s = _givens_parameter(S[low].conj() * S[low] - mu, S[low].conj() * F[low]) + else: + c, s = _givens_parameter(F[i - 1], X) + V[i + 1], V[i] = -s * V[i] + c * V[i + 1], c.conj() * V[i] + s.conj() * V[i + 1] + if i != low: + F[i - 1] = c.conj() * F[i - 1] + s.conj() * X + F[i], S[i] = -s * S[i] + c * F[i], c.conj() * S[i] + s.conj() * F[i] + S[i + 1], X = c * S[i + 1], s.conj() * S[i + 1] + + c, s = _givens_parameter(S[i], X) + U[i + 1], U[i] = -s * U[i] + c * U[i + 1], c.conj() * U[i] + s.conj() * U[i + 1] + + S[i] = c.conj() * S[i] + s.conj() * X + S[i + 1], F[i] = -s * F[i] + c * S[i + 1], c.conj() * F[i] + s.conj() * S[i + 1] + if i != high - 1: + F[i + 1], X = c * F[i + 1], s.conj() * F[i + 1] + + # Make diagonal positive + c = _normalize_diagonal(S).conj() + V *= c.unsqueeze(1) # U is larger than V + S *= c + S = S.real + + # Sort + S, order = torch.sort(S, descending=True) + U = U[order] + V = V[order] + + # pylint: disable=no-else-return + if trans: + return V.H, S, U.H.T + else: + return U.H, S, V.H.T + + +def _skew(A: torch.Tensor) -> torch.Tensor: + return A - A.H + + +def _svd_backward( + U: torch.Tensor, + S: torch.Tensor, + Vh: torch.Tensor, + gU: typing.Optional[torch.Tensor], + gS: typing.Optional[torch.Tensor], + gVh: typing.Optional[torch.Tensor], +) -> typing.Optional[torch.Tensor]: + # pylint: disable=too-many-locals + # pylint: disable=too-many-branches + # pylint: disable=too-many-arguments + + # See pytorch torch/csrc/autograd/FunctionsManual.cpp:svd_backward + if gS is None and gU is None and gVh is None: + return None + + m = U.size(0) + n = Vh.size(1) + + if gU is None and gVh is None: + assert gS is not None + # pylint: disable=no-else-return + if m >= n: + return U @ (gS.unsqueeze(1) * Vh) + else: + return (U * gS.unsqueeze(0)) @ Vh + + is_complex = torch.is_complex(U) + + UhgU = _skew(U.H @ gU) if gU is not None else None + VhgV = _skew(Vh @ gVh.H) if gVh is not None else None + + S2 = S * S + E = S2.unsqueeze(0) - S2.unsqueeze(1) + E.diagonal()[:] = 1 + + if gU is not None: + if gVh is not None: + assert UhgU is not None + assert VhgV is not None + gA = (UhgU * S.unsqueeze(0) + S.unsqueeze(1) * VhgV) / E + else: + assert UhgU is not None + gA = (UhgU / E) * S.unsqueeze(0) + else: + assert VhgV is not None + gA = S.unsqueeze(1) * (VhgV / E) + + if gS is not None: + gA = gA + torch.diag(gS) + + if is_complex and gU is not None and gVh is not None: + assert UhgU is not None + gA = gA + torch.diag(UhgU.diagonal() / (2 * S)) + + if m > n and gU is not None: + gA = U @ gA + gUSinv = gU / S.unsqueeze(0) + gA = gA + gUSinv - U @ (U.H @ gUSinv) + gA = gA @ Vh + elif m < n and gVh is not None: + gA = gA @ Vh + SinvgVh = gVh / S.unsqueeze(1) + gA = gA + SinvgVh - (SinvgVh @ Vh.H) @ Vh + gA = U @ gA + elif m >= n: + gA = U @ (gA @ Vh) + else: + gA = (U @ gA) @ Vh + + return gA + + +class SVD(torch.autograd.Function): + """ + Compute SVD decomposition without Householder reflection. + """ + + # pylint: disable=abstract-method + + @staticmethod + def forward( # type: ignore[override] + ctx: torch.autograd.function.FunctionCtx, + A: torch.Tensor, + error: float, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # pylint: disable=arguments-differ + U, S, V = _svd(A, error) + ctx.save_for_backward(U, S, V) + return U, S, V + + @staticmethod + def backward( # type: ignore[override] + ctx: typing.Any, + U_grad: typing.Optional[torch.Tensor], + S_grad: typing.Optional[torch.Tensor], + V_grad: typing.Optional[torch.Tensor], + ) -> tuple[typing.Optional[torch.Tensor], None]: + # pylint: disable=arguments-differ + U, S, V = ctx.saved_tensors + return _svd_backward(U, S, V, U_grad, S_grad, V_grad), None + + +svd = SVD.apply diff --git a/tat/_utility.py b/tat/_utility.py new file mode 100644 index 000000000..0a52f9166 --- /dev/null +++ b/tat/_utility.py @@ -0,0 +1,37 @@ +""" +Some internal utilities used by tat. +""" + +import torch + +# pylint: disable=missing-function-docstring +# pylint: disable=no-else-return + + +def unsqueeze(tensor: torch.Tensor, index: int, rank: int) -> torch.Tensor: + return tensor.reshape([-1 if i == index else 1 for i in range(rank)]) + + +def neg_symmetry(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dtype is torch.bool: + return tensor + else: + return -tensor + + +def add_symmetry(tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor: + if tensor_1.dtype is torch.bool: + return tensor_1 ^ tensor_2 + else: + return tensor_1 + tensor_2 + + +def zero_symmetry(tensor: torch.Tensor) -> torch.Tensor: + return tensor == torch.zeros([], dtype=tensor.dtype) + + +def parity(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dtype is torch.bool: + return tensor + else: + return tensor % 2 != 0 diff --git a/tat/compat.py b/tat/compat.py new file mode 100644 index 000000000..9ea3deb3f --- /dev/null +++ b/tat/compat.py @@ -0,0 +1,439 @@ +""" +This file implements a compat layer for legacy TAT interface. +""" + +from __future__ import annotations +import typing +from multimethod import multimethod +import torch +from .edge import Edge as E +from .tensor import Tensor as T + +# pylint: disable=too-few-public-methods +# pylint: disable=too-many-instance-attributes +# pylint: disable=redefined-outer-name + + +class Symmetry(tuple): + """ + The compat symmetry constructor, without detailed type check. + """ + + def __new__(cls: type[Symmetry], *sym: typing.Any) -> Symmetry: + if len(sym) == 1 and isinstance(sym[0], tuple): + sym = sym[0] + return tuple.__new__(Symmetry, sym) + + def __neg__(self: Symmetry) -> Symmetry: + return Symmetry(tuple(sub_sym if isinstance(sub_sym, bool) else -sub_sym for sub_sym in self)) + + +class CompatSymmetry: + """ + The common Symmetry namespace. + """ + + def __init__(self: CompatSymmetry, fermion: list[bool], dtypes: list[torch.dtype]) -> None: + # This create fake module like TAT.No, TAT.Z2 or similar things, it need to specify the symmetry attributes. + # symmetry is set by two attributes: fermion and dtypes. + self.fermion: list[bool] = fermion + self.dtypes: list[torch.dtype] = dtypes + + # pylint: disable=invalid-name + self.S: CompatScalar + self.D: CompatScalar + self.C: CompatScalar + self.Z: CompatScalar + self.float32: CompatScalar + self.float64: CompatScalar + self.float: CompatScalar + self.complex64: CompatScalar + self.complex128: CompatScalar + self.complex: CompatScalar + + # In old TAT, something like TAT.No.D is a sub module for tensor with specific scalar type. + # In this compat library, it is implemented by another fake module: CompatScalar. + self.S = self.float32 = CompatScalar(self, torch.float32) + self.D = self.float64 = self.float = CompatScalar(self, torch.float64) + self.C = self.complex64 = CompatScalar(self, torch.complex64) + self.Z = self.complex128 = self.complex = CompatScalar(self, torch.complex128) + + self.Edge: CompatEdge = CompatEdge(self) + self.Symmetry: type[Symmetry] = Symmetry + + +class CompatEdge: + """ + The compat edge constructor. + """ + + def __init__(self: CompatEdge, owner: CompatSymmetry) -> None: + self.fermion: list[bool] = owner.fermion + self.dtypes: list[torch.dtype] = owner.dtypes + + def _parse_segments(self: CompatEdge, segments: list) -> tuple[list[torch.Tensor], int]: + # In TAT, user could use [Sym] or [(Sym, Size)] to set segments of a edge, where [(Sym, Size)] is nothing but + # the symmetry and size of every blocks. While [Sym] acts like [(Sym, 1)], so try to treat input as + # [(Sym, Size)] First, if error raised, convert it from [Sym] to [(Sym, 1)] and try again. + try: + # try [(Sym, Size)] first + return self._parse_segments_kernel(segments) + except TypeError: + # Cannot unpack is a type error, value[index] is a type error, too. So only catch TypeError here. + # convert [Sym] to [(Sym, Size)] + return self._parse_segments_kernel([(sym, 1) for sym in segments]) + # This function return the symmetry list and dimension + + def _parse_segments_kernel( + self: CompatEdge, + segments: list[tuple[typing.Any, int]], + ) -> tuple[list[torch.Tensor], int]: + # [(Sym, Size)] for every element + dimension = sum(dim for _, dim in segments) + symmetry = [ + torch.tensor( + # tat.Edge need torch.Tensor as its symmetry, convert it to torch.Tensor with specific dtype. + sum( + # Concat all segment for this sub symmetry from an empty list + # Every segment is just sym[index] * dim, sometimes sym may be sub symmetry itself directly instead + # of tuple of sub symmetry, so call an utility function _parse_segments_get_subsymmetry here. + ([self._parse_segments_get_subsymmetry(sym, index)] * dim + for sym, dim in segments), + [], + ), + dtype=sub_symmetry, + ) + # Generate sub symmetry one by one + for index, sub_symmetry in enumerate(self.dtypes) + ] + return symmetry, dimension + + def _parse_segments_get_subsymmetry(self: CompatEdge, sym: object, index: int) -> object: + # Most of time, symmetry is a tuple of sub symmetry + # But if there is only one sub symmetry in the symmetry, it could not be a tuple but sub symmetry itself. + # pylint: disable=no-else-return + if isinstance(sym, tuple): + # If it is tuple, there is no need to do any other check + return sym[index] + else: + # If it is not tuple, it should be sub symmetry directly, so this symmetry only should own single sub + # symmetry, check it. + if len(self.fermion) == 1: + return sym + else: + raise TypeError(f"{sym=} is not subscript-able") + + @multimethod + def __call__(self: CompatEdge, edge: E) -> E: + """ + Create edge with compat interface. + + It may be created by + 1. Edge(dimension) create trivial symmetry with specified dimension. + 2. Edge(segments, arrow) create with given segments and arrow. + 3. Edge(segments_arrow_tuple) create with a tuple of segments and arrow. + """ + # pylint: disable=invalid-name + return edge + + @__call__.register + def _(self: CompatEdge, dimension: int) -> E: + # Generate a trivial symmetry tuple. In this tuple, every sub symmetry is a torch.zeros tensor with specific + # dtype and the same dimension. + symmetry = [torch.zeros(dimension, dtype=sub_symmetry) for sub_symmetry in self.dtypes] + return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=False) + + @__call__.register + def _(self: CompatEdge, segments: list, arrow: bool = False) -> E: + symmetry, dimension = self._parse_segments(segments) + return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow) + + @__call__.register + def _(self: CompatEdge, segments_and_bool: tuple[list, bool]) -> E: + segments, arrow = segments_and_bool + symmetry, dimension = self._parse_segments(segments) + return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow) + + +class CompatScalar: + """ + The common Scalar namespace. + """ + + def __init__(self: CompatScalar, symmetry: CompatSymmetry, dtype: torch.dtype) -> None: + # This is fake module like TAT.No.D, TAT.Fermi.complex, so it records the parent symmetry information and its + # own dtype. + self.symmetry: CompatSymmetry = symmetry + self.dtype: torch.dtype = dtype + # pylint: disable=invalid-name + self.Tensor: CompatTensor = CompatTensor(self) + + +class CompatTensor: + """ + The compat tensor constructor. + """ + + def __init__(self: CompatTensor, owner: CompatScalar) -> None: + self.symmetry: CompatSymmetry = owner.symmetry + self.dtype: torch.dtype = owner.dtype + self.model: CompatSymmetry = owner.symmetry + self.is_complex: bool = self.dtype.is_complex + self.is_real: bool = self.dtype.is_floating_point + + @multimethod + def __call__(self: CompatTensor, tensor: T) -> T: + """ + Create tensor with compat names and edges. + + It may be create by + 1. Tensor(names, edges) The most used interface. + 2. Tensor() Create a rank-0 tensor, fill with number 1. + 3. Tensor(number, names=[], edge_symmetry=[], edge_arrow=[]) Create a size-1 tensor, with specified edge, and + filled with specified number. + """ + # pylint: disable=invalid-name + return tensor + + @__call__.register + def _(self: CompatTensor, names: list[str], edges: list) -> T: + return T( + names, + [self.symmetry.Edge(edge) for edge in edges], + fermion=self.symmetry.fermion, + dtypes=self.symmetry.dtypes, + dtype=self.dtype, + ) + + @__call__.register + def _(self: CompatTensor) -> T: + result = T( + [], + [], + fermion=self.symmetry.fermion, + dtypes=self.symmetry.dtypes, + data=torch.ones([], dtype=self.dtype), + ) + return result + + @__call__.register + def _( + self: CompatTensor, + number: typing.Any, + names: typing.Optional[list[str]] = None, + edge_symmetry: typing.Optional[list] = None, + edge_arrow: typing.Optional[list[bool]] = None, + ) -> T: + # Create high rank tensor with only one element + if names is None: + names = [] + if edge_symmetry is None: + edge_symmetry = [None for _ in names] + if edge_arrow is None: + edge_arrow = [False for _ in names] + result = T( + names, + [ + # Create edge for every rank, given the only symmetry(maybe None) and arrow. + E( + fermion=self.symmetry.fermion, + dtypes=self.symmetry.dtypes, + # For every edge, its symmetry is a list of all sub symmetry. + symmetry=[ + # For every sub symmetry, get the only symmetry for it, since dimension of all edge is 1. + # It should be noticed that the symmetry may be None, tuple or sub symmetry directly. + torch.tensor([self._create_size1_get_subsymmetry(symmetry, index)], dtype=dtype) + for index, dtype in enumerate(self.symmetry.dtypes) + ], + dimension=1, + arrow=arrow, + ) + for symmetry, arrow in zip(edge_symmetry, edge_arrow) + ], + fermion=self.symmetry.fermion, + dtypes=self.symmetry.dtypes, + data=torch.full([1 for _ in names], number, dtype=self.dtype), + ) + return result + + def _create_size1_get_subsymmetry(self: CompatTensor, sym: object, index: int) -> object: + # pylint: disable=no-else-return + # sym may be None, tuple or sub symmetry directly. + if sym is None: + # If is None, user may want to create symmetric edge with trivial symmetry, which should be 0 for int and + # False for bool, always return 0 here, since it will be converted to correct type by torch.tensor. + return 0 + elif isinstance(sym, tuple): + # If it is tuple, there is no need to do any other check + return sym[index] + else: + # If it is not tuple, it should be sub symmetry directly, so this symmetry only should own single sub + # symmetry, check it. + if len(self.symmetry.fermion) == 1: + return sym + else: + raise TypeError(f"{sym=} is not subscript-able") + + +# Create fake sub module for all symmetry compiled in old version TAT +No: CompatSymmetry = CompatSymmetry(fermion=[], dtypes=[]) +Z2: CompatSymmetry = CompatSymmetry(fermion=[False], dtypes=[torch.bool]) +U1: CompatSymmetry = CompatSymmetry(fermion=[False], dtypes=[torch.int]) +Fermi: CompatSymmetry = CompatSymmetry(fermion=[True], dtypes=[torch.int]) +FermiZ2: CompatSymmetry = CompatSymmetry(fermion=[True, False], dtypes=[torch.int, torch.bool]) +FermiU1: CompatSymmetry = CompatSymmetry(fermion=[True, False], dtypes=[torch.int, torch.int]) +Parity: CompatSymmetry = CompatSymmetry(fermion=[True], dtypes=[torch.bool]) +FermiFermi: CompatSymmetry = CompatSymmetry(fermion=[True, True], dtypes=[torch.int, torch.int]) +Normal: CompatSymmetry = No + +# SJ Dong's convention + + +def arrow(int_arrow: int) -> bool: + "SJ Dong's convention of arrow" + # pylint: disable=no-else-return + if int_arrow == +1: + return False + elif int_arrow == -1: + return True + else: + raise ValueError("int arrow should be +1 or -1.") + + +def parity(int_parity: int) -> bool: + "SJ Dong's convention of parity" + # pylint: disable=no-else-return + if int_parity == +1: + return False + elif int_parity == -1: + return True + else: + raise ValueError("int parity should be +1 or -1.") + + +# Segment index + + +@T._prepare_position.register # pylint: disable=protected-access,no-member +def _(self: T, position: dict[str, tuple[typing.Any, int]]) -> tuple[int, ...]: + return tuple(index_by_point(edge, position[name]) for name, edge in zip(self.names, self.edges)) + + +# Add some compat interface + + +def _compat_function( + focus_type: type, + name: typing.Optional[str] = None, +) -> typing.Callable[[typing.Callable], typing.Callable]: + + def _result(function: typing.Callable) -> typing.Callable: + if name is None: + attr_name = function.__name__ + else: + attr_name = name + setattr(focus_type, attr_name, function) + return function + + return _result + + +@property # type: ignore[misc] +def storage(self: T) -> typing.Any: + "Get the storage of the tensor" + assert self.data.is_contiguous() + return self.data.reshape([-1]) + + +@_compat_function(T, name="storage") # type: ignore[misc] +@storage.setter +def storage(self: T, value: typing.Any) -> None: + "Set the storage of the tensor" + assert self.data.is_contiguous() + self.data.reshape([-1])[:] = torch.as_tensor(value) + + +@_compat_function(T) +def range_(self: T, first: float = 0, step: float = 1) -> T: + "Compat Interface: Set range inplace for this tensor." + result = self.range(first, step) + self._data = result._data # pylint: disable=protected-access + return self + + +@_compat_function(T) +def identity_(self: T, pairs: set[tuple[str, str]]) -> T: + "Compat Interface: Set idenity inplace for this tensor." + result = self.identity(pairs).transpose(self.names) + self._data = result._data # pylint: disable=protected-access + return self + + +# Exponential arguments + +origin_exponential = T.exponential + + +@_compat_function(T) +def exponential(self: T, pairs: set[tuple[str, str]], step: typing.Optional[int] = None) -> T: + "Compat Interface: Get the exponential tensor of this tensor." + # pylint: disable=unused-argument + return origin_exponential(self, pairs) + + +# Edge point conversion + + +@_compat_function(E) +def index_by_point(self: E, point: tuple[typing.Any, int]) -> int: + "Get index by point on an edge" + sym, sub_index = point + if not isinstance(sym, tuple): + sym = (sym,) + for total_index in range(self.dimension): + if all(sub_sym == sub_symmetry[total_index] for sub_sym, sub_symmetry in zip(sym, self.symmetry)): + if sub_index == 0: + return total_index + sub_index = sub_index - 1 + raise ValueError("Invalid input point") + + +@_compat_function(E) +def point_by_index(self: E, index: int) -> tuple[typing.Any, int]: + "Get point by index on an edge" + sym = Symmetry(tuple(sub_symmetry[index] for sub_symmetry in self.symmetry)) + sub_index = sum( + 1 for i in range(index) if all(sub_sym == sub_symmetry[i] for sub_sym, sub_symmetry in zip(sym, self.symmetry))) + return sym, sub_index + + +# Random utility + + +class CompatRandom: + """ + Fake module for compat random utility in TAT. + """ + + def uniform_int(self: CompatRandom, low: int, high: int) -> typing.Callable[[], int]: + "Generator for integer uniform distribution" + # Mypy cannot recognize item of int64 tensor is int, so cast it manually. + return staticmethod( # type: ignore[return-value] # python3.9 does not treat staticmethod as callable + lambda: int(torch.randint(low, high + 1, [], dtype=torch.int64).item())) + + def uniform_real(self: CompatRandom, low: float, high: float) -> typing.Callable[[], float]: + "Generator for float uniform distribution" + return staticmethod( # type: ignore[return-value] # python3.9 does not treat staticmethod as callable + lambda: torch.rand([], dtype=torch.float64).item() * (high - low) + low) + + def normal(self: CompatRandom, mean: float, stddev: float) -> typing.Callable[[], float]: + "Generator for float normal distribution" + return staticmethod( # type: ignore[return-value] # python3.9 does not treat staticmethod as callable + lambda: torch.normal(mean, stddev, [], dtype=torch.float64).item()) + + def seed(self: CompatRandom, new_seed: int) -> None: + "Set the seed for random generator manually" + torch.manual_seed(new_seed) + + +random = CompatRandom() diff --git a/tat/edge.py b/tat/edge.py new file mode 100644 index 000000000..c9a603bb3 --- /dev/null +++ b/tat/edge.py @@ -0,0 +1,297 @@ +""" +This file contains the definition of tensor edge. +""" + +from __future__ import annotations +import functools +import operator +import typing +import torch +from . import _utility + + +class Edge: + """ + The edge type of tensor. + """ + + __slots__ = "_fermion", "_dtypes", "_symmetry", "_dimension", "_arrow", "_parity" + + @property + def fermion(self: Edge) -> list[bool]: + """ + A list records whether every sub symmetry is fermionic. Its length is the number of sub symmetry. + """ + return self._fermion + + @property + def dtypes(self: Edge) -> list[torch.dtype]: + """ + A list records the basic dtype of every sub symmetry. Its length is the number of sub symmetry. + """ + return self._dtypes + + @property + def symmetry(self: Edge) -> list[torch.Tensor]: + """ + A list containing all symmetry of this edge. Its length is the number of sub symmetry. Every element of it is a + sub symmetry. + """ + return self._symmetry + + @property + def dimension(self: Edge) -> int: + """ + The dimension of this edge. + """ + return self._dimension + + @property + def arrow(self: Edge) -> bool: + """ + The arrow of this edge. + """ + return self._arrow + + @property + def parity(self: Edge) -> torch.Tensor: + """ + The parity of this edge. + """ + return self._parity + + def __init__( + self: Edge, + *, + fermion: typing.Optional[list[bool]] = None, + dtypes: typing.Optional[list[torch.dtype]] = None, + symmetry: typing.Optional[list[torch.Tensor]] = None, + dimension: typing.Optional[int] = None, + arrow: typing.Optional[bool] = None, + # The following argument is not public + parity: typing.Optional[torch.Tensor] = None, + ) -> None: + """ + Create an edge with essential information. + + Examples: + - Edge(dimension=5) + - Edge(symmetry=[torch.tensor([False, False, True, True])]) + - Edge(fermion=[False, True], symmetry=[torch.tensor([False, True]), torch.tensor([False, True])], arrow=True) + + Parameters + ---------- + fermion : list[bool], optional + Whether each sub symmetry is fermionic symmetry, its length should be the same to symmetry. But it could be + left empty, if so, a total bosonic edge will be created. + dtypes : list[torch.dtype], optional + The basic dtype to identify each sub symmetry, its length should be the same to symmetry, and it is nothing + but the dtypes of each tensor in the symmetry. It could be left empty, if so, it will be derived from + symmetry. + symmetry : list[torch.Tensor], optional + The symmetry information of every sub symmetry, each of sub symmetry should be a one dimensional tensor with + the same length dimension, and their dtype should be integral type, aka, int or bool. + dimension : int, optional + The dimension of the edge, if not specified, dimension will be detected from symmetry. + arrow : bool, optional + The arrow direction of the edge, it is essential for fermionic edge, aka, an edge with fermionic sub + symmetry. + """ + # pylint: disable=too-many-arguments + + # Symmetry could be left empty to create no symmetry edge + if symmetry is None: + symmetry = [] + + # Fermion could be empty if it is total bosonic edge + if fermion is None: + fermion = [False for _ in symmetry] + + # Dtypes could be empty and derived from symmetry + if dtypes is None: + dtypes = [sub_symmetry.dtype for sub_symmetry in symmetry] + # Check dtype is compatible with symmetry + assert all(sub_symmetry.dtype is sub_dtype for sub_symmetry, sub_dtype in zip(symmetry, dtypes)) + # Check dtype is valid, aka, bool or int + assert all(not (sub_symmetry.is_floating_point() or sub_symmetry.is_complex()) for sub_symmetry in symmetry) + + # The fermion, dtypes and symmetry information should have the same length + assert len(fermion) == len(dtypes) == len(symmetry) + + # If dimension not set, get dimension from symmetry + if dimension is None: + dimension = len(symmetry[0]) + # Check if the dimensions of different sub_symmetry mismatch + assert all(sub_symmetry.size() == (dimension,) for sub_symmetry in symmetry) + + if arrow is None: + # Arrow not set, it should be bosonic edge. + arrow = False + assert not any(fermion) + + self._fermion: list[bool] = fermion + self._dtypes: list[torch.dtype] = dtypes + self._symmetry: list[torch.Tensor] = symmetry + self._dimension: int = dimension + self._arrow: bool = arrow + + if parity is None: + parity = self._generate_parity() + self._parity: torch.Tensor = parity + assert self.parity.size() == (self.dimension,) + assert self.parity.dtype is torch.bool + + def _generate_parity(self: Edge) -> torch.Tensor: + return functools.reduce( + # Reduce sub parity for all sub symmetry by logical xor + torch.logical_xor, + ( + # The parity of sub symmetry + _utility.parity(sub_symmetry) + # Loop all sub symmetry + for sub_symmetry, sub_fermion in zip(self.symmetry, self.fermion) + # But only reduce if it is fermion sub symmetry + if sub_fermion), + # Reduce with start as tensor filled with False + torch.zeros(self.dimension, dtype=torch.bool), + ) + + def conjugate(self: Edge) -> Edge: + """ + Get the conjugated edge. + + Returns + ------- + Edge + The conjugated edge. + """ + # The only two difference of conjugated edge is symmetry and arrow + return Edge( + fermion=self.fermion, + dtypes=self.dtypes, + symmetry=[ + _utility.neg_symmetry(sub_symmetry) # bool -> same, int -> neg + for sub_symmetry in self.symmetry + ], + dimension=self.dimension, + arrow=not self.arrow, + parity=self.parity, + ) + + def __eq__(self: Edge, other: typing.Any) -> bool: + if not isinstance(other, Edge): + # pylint: disable=no-else-return + if torch.jit.is_scripting(): + return False + else: + return NotImplemented + return ( + # Compare int dimension and bool arrow first since they are fast to compare + self.dimension == other.dimension and + # But even if arrow are different, if it is bosonic edge, it is also OK + (self.arrow == other.arrow or not any(self.fermion)) and + # Then the list of bool are compared + self.fermion == other.fermion and + # Then the list of dtypes are compared + self.dtypes == other.dtypes and + # All of symmetries are compared at last, since it is biggest + all( + torch.equal(self_sub_symmetry, other_sub_symmetry) + for self_sub_symmetry, other_sub_symmetry in zip(self.symmetry, other.symmetry))) + + def __str__(self: Edge) -> str: + # pylint: disable=no-else-return + if any(self.fermion): + # Fermionic edge + fermion = ','.join(str(sub_fermion) for sub_fermion in self.fermion) + symmetry = ','.join( + f"[{','.join(str(sub_sym.item()) for sub_sym in sub_symmetry)}]" for sub_symmetry in self.symmetry) + return f"(dimension={self.dimension}, arrow={self.arrow}, fermion=({fermion}), symmetry=({symmetry}))" + elif self.fermion: + # Bosonic edge + symmetry = ','.join( + f"[{','.join(str(sub_sym.item()) for sub_sym in sub_symmetry)}]" for sub_symmetry in self.symmetry) + return f"(dimension={self.dimension}, symmetry=({symmetry}))" + else: + # Trivial edge + return f"(dimension={self.dimension})" + + def __repr__(self: Edge) -> str: + return f"Edge{self.__str__()}" + + @staticmethod + def merge_edges( + edges: list[Edge], + *, + fermion: typing.Optional[list[bool]] = None, + dtypes: typing.Optional[list[torch.dtype]] = None, + arrow: typing.Optional[bool] = None, + ) -> Edge: + """ + Merge several edges into one edge. + + Parameters + ---------- + edges : list[Edge] + The edges to be merged. + fermion : list[bool], optional + Whether each sub symmetry is fermionic, it could be left empty to derive from edges + dtypes : list[torch.dtype], optional + The base type of sub symmetry, it could be left empty to derive from edges + arrow : bool, optional + The arrow of all the edges, it is useful if edges is empty. + + Returns + ------- + Edge + The result edge merged by edges. + """ + # If fermion not set, get it from edges + if fermion is None: + fermion = edges[0].fermion + # All edge should share the same fermion + assert all(fermion == edge.fermion for edge in edges) + # If dtypes not set, get it from edges + if dtypes is None: + dtypes = edges[0].dtypes + # All edge should share the same dtypes + assert all(dtypes == edge.dtypes for edge in edges) + # If arrow set, check it directly, if not set, set to False or get from edges + if arrow is None: + if any(fermion): + # It is fermionic edge. + arrow = edges[0].arrow + else: + # It is bosonic edge, set to False directly since it is useless. + arrow = False + # All edge should share the same arrow for fermionic edge + assert (not any(fermion)) or all(arrow == edge.arrow for edge in edges) + + rank = len(edges) + # Merge edge + dimension = functools.reduce(operator.mul, (edge.dimension for edge in edges), 1) + symmetry = [ + # Every merged sub symmetry is calculated by reduce and flatten + functools.reduce( + # The reduce operator depend on the dtype of this sub symmetry + _utility.add_symmetry, + ( + # The sub symmetry of every edge will be reshape to be reduced. + _utility.unsqueeze(edge.symmetry[sub_symmetry_index], current_index, rank) + # The sub symmetry of every edge is reduced one by one + for current_index, edge in enumerate(edges)), + # Reduce from a rank-0 tensor + torch.zeros([], dtype=sub_symmetry_dtype), + ).reshape([-1]) + # Merge every sub symmetry one by one + for sub_symmetry_index, sub_symmetry_dtype in enumerate(dtypes) + ] + + # parity not set here since it need recalculation + return Edge( + fermion=fermion, + dtypes=dtypes, + symmetry=symmetry, + dimension=dimension, + arrow=arrow, + ) diff --git a/tat/tensor.py b/tat/tensor.py new file mode 100644 index 000000000..bd7de0eff --- /dev/null +++ b/tat/tensor.py @@ -0,0 +1,1947 @@ +""" +This file defined the core tensor type for tat package. +""" + +from __future__ import annotations +import typing +import operator +import functools +from multimethod import multimethod +import torch +from . import _utility +from ._qr import givens_qr, householder_qr # pylint: disable=unused-import +from ._svd import svd as manual_svd # pylint: disable=unused-import +from .edge import Edge + +# pylint: disable=too-many-public-methods +# pylint: disable=too-many-lines + + +class Tensor: + """ + The main tensor type, which wraps pytorch tensor and provides edge names and Fermionic functions. + """ + + __slots__ = "_fermion", "_dtypes", "_names", "_edges", "_data", "_mask" + + def __str__(self: Tensor) -> str: + return f"(names={self.names}, edges={self.edges}, data={self.data})" + + def __repr__(self: Tensor) -> str: + return f"Tensor(names={self.names}, edges={self.edges})" + + @property + def fermion(self: Tensor) -> list[bool]: + """ + A list records whether every sub symmetry is fermionic. Its length is the number of sub symmetry. + """ + return self._fermion + + @property + def dtypes(self: Tensor) -> list[torch.dtype]: + """ + A list records the basic dtype of every sub symmetry. Its length is the number of sub symmetry. + """ + return self._dtypes + + @property + def names(self: Tensor) -> list[str]: + """ + The edge names of this tensor. + """ + return self._names + + @property + def edges(self: Tensor) -> list[Edge]: + """ + The edges information of this tensor. + """ + return self._edges + + @property + def data(self: Tensor) -> torch.Tensor: + """ + The content data of this tensor. + """ + return self._data + + @property + def mask(self: Tensor) -> torch.Tensor: + """ + The content data mask of this tensor. + """ + return self._mask + + @property + def rank(self: Tensor) -> int: + """ + The rank of this tensor. + """ + return len(self._names) + + @property + def dtype(self: Tensor) -> torch.dtype: + """ + The data type of the content in this tensor. + """ + return self.data.dtype + + @property + def btype(self: Tensor) -> str: + """ + The data type of the content in this tensor, represented in BLAS/LAPACK convention. + """ + if self.dtype is torch.float32: + return 'S' + if self.dtype is torch.float64: + return 'D' + if self.dtype is torch.complex64: + return 'C' + if self.dtype is torch.complex128: + return 'Z' + return '?' + + @property + def is_complex(self: Tensor) -> bool: + """ + Whether it is a complex tensor + """ + return self.dtype.is_complex + + @property + def is_real(self: Tensor) -> bool: + """ + Whether it is a real tensor + """ + return self.dtype.is_floating_point + + def edge_by_name(self: Tensor, name: str) -> Edge: + """ + Get edge by the edge name of this tensor. + + Parameters + ---------- + name : str + The given edge name. + + Returns + ------- + Edge + The edge with the given edge name. + """ + assert name in self.names + return self.edges[self.names.index(name)] + + def _arithmetic_operator(self: Tensor, other: object, operate: typing.Callable) -> Tensor: + new_data: torch.Tensor + if isinstance(other, Tensor): + # If it is tensor, check same shape and transpose before calculating. + assert self.same_shape_with(other) + new_data = operate(self.data, other.transpose(self.names).data) + if operate is torch.div: + # In div, it may generate nan + new_data = torch.where(self.mask, new_data, torch.zeros([], dtype=self.dtype)) + else: + # Otherwise treat other as a scalar, mask should be applied later. + new_data = operate(self.data, other) + new_data = torch.where(self.mask, new_data, torch.zeros([], dtype=self.dtype)) + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=new_data, + mask=self.mask, + ) + + def __add__(self: Tensor, other: object) -> Tensor: + return self._arithmetic_operator(other, torch.add) + + def __sub__(self: Tensor, other: object) -> Tensor: + return self._arithmetic_operator(other, torch.sub) + + def __mul__(self: Tensor, other: object) -> Tensor: + return self._arithmetic_operator(other, torch.mul) + + def __truediv__(self: Tensor, other: object) -> Tensor: + return self._arithmetic_operator(other, torch.div) + + def _right_arithmetic_operator(self: Tensor, other: object, operate: typing.Callable) -> Tensor: + new_data: torch.Tensor + if isinstance(other, Tensor): + # If it is tensor, check same shape and transpose before calculating. + assert self.same_shape_with(other) + new_data = operate(other.transpose(self.names).data, self.data) + if operate is torch.div: + # In div, it may generate nan + new_data = torch.where(self.mask, new_data, torch.zeros([], dtype=self.dtype)) + else: + # Otherwise treat other as a scalar, mask should be applied later. + new_data = operate(other, self.data) + new_data = torch.where(self.mask, new_data, torch.zeros([], dtype=self.dtype)) + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=new_data, + mask=self.mask, + ) + + def __radd__(self: Tensor, other: object) -> Tensor: + return self._right_arithmetic_operator(other, torch.add) + + def __rsub__(self: Tensor, other: object) -> Tensor: + return self._right_arithmetic_operator(other, torch.sub) + + def __rmul__(self: Tensor, other: object) -> Tensor: + return self._right_arithmetic_operator(other, torch.mul) + + def __rtruediv__(self: Tensor, other: object) -> Tensor: + return self._right_arithmetic_operator(other, torch.div) + + def _inplace_arithmetic_operator(self: Tensor, other: object, operate: typing.Callable) -> Tensor: + if isinstance(other, Tensor): + # If it is tensor, check same shape and transpose before calculating. + assert self.same_shape_with(other) + operate(self.data, other.transpose(self.names).data, out=self.data) + if operate is torch.div: + # In div, it may generate nan + torch.where(self.mask, self.data, torch.zeros([], dtype=self.dtype), out=self.data) + else: + # Otherwise treat other as a scalar, mask should be applied later. + operate(self.data, other, out=self.data) + torch.where(self.mask, self.data, torch.zeros([], dtype=self.dtype), out=self.data) + return self + + def __iadd__(self: Tensor, other: object) -> Tensor: + return self._inplace_arithmetic_operator(other, torch.add) + + def __isub__(self: Tensor, other: object) -> Tensor: + return self._inplace_arithmetic_operator(other, torch.sub) + + def __imul__(self: Tensor, other: object) -> Tensor: + return self._inplace_arithmetic_operator(other, torch.mul) + + def __itruediv__(self: Tensor, other: object) -> Tensor: + return self._inplace_arithmetic_operator(other, torch.div) + + def __pos__(self: Tensor) -> Tensor: + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=+self.data, + mask=self.mask, + ) + + def __neg__(self: Tensor) -> Tensor: + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=-self.data, + mask=self.mask, + ) + + def __float__(self: Tensor) -> float: + return float(self.data) + + def __complex__(self: Tensor) -> complex: + return complex(self.data) + + def norm(self: Tensor, order: typing.Any) -> float: + """ + Get the norm of the tensor, this function will flatten tensor first before calculate norm. + + Parameters + ---------- + order + The order of norm. + + Returns + ------- + float + The norm of the tensor. + """ + return torch.linalg.vector_norm(self.data, ord=order) + + def norm_max(self: Tensor) -> float: + "max norm" + return self.norm(+torch.inf) + + def norm_min(self: Tensor) -> float: + "min norm" + return self.norm(-torch.inf) + + def norm_num(self: Tensor) -> float: + "0-norm" + return self.norm(0) + + def norm_sum(self: Tensor) -> float: + "1-norm" + return self.norm(1) + + def norm_2(self: Tensor) -> float: + "2-norm" + return self.norm(2) + + def copy(self: Tensor) -> Tensor: + """ + Get a copy of this tensor + + Returns + ------- + Tensor + The copy of this tensor + """ + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=torch.clone(self.data, memory_format=torch.contiguous_format), + mask=self.mask, + ) + + def __copy__(self: Tensor) -> Tensor: + return self.copy() + + def __deepcopy__(self: Tensor, _: typing.Any = None) -> Tensor: + return self.copy() + + def same_shape(self: Tensor) -> Tensor: + """ + Get a tensor with same shape to this tensor + + Returns + ------- + Tensor + A new tensor with the same shape to this tensor + """ + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=torch.zeros_like(self.data), + mask=self.mask, + ) + + def zero_(self: Tensor) -> Tensor: + """ + Set all element to zero in this tensor + + Returns + ------- + Tensor + Return this tensor itself. + """ + self.data.zero_() + return self + + def sqrt(self: Tensor) -> Tensor: + """ + Get the sqrt of the tensor. + + Returns + ------- + Tensor + The sqrt of this tensor. + """ + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=torch.sqrt(torch.abs(self.data)), + mask=self.mask, + ) + + def reciprocal(self: Tensor) -> Tensor: + """ + Get the reciprocal of the tensor. + + Returns + ------- + Tensor + The reciprocal of this tensor. + """ + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=torch.where(self.data == 0, self.data, 1 / self.data), + mask=self.mask, + ) + + def range(self: Tensor, first: typing.Any = 0, step: typing.Any = 1) -> Tensor: + """ + A useful function to Get tensor filled with simple data for test in the same shape. + + Parameters + ---------- + first, step + Parameters to generate data. + + Returns + ------- + Tensor + Returns the tensor filled with simple data for test. + """ + data = torch.cumsum(self.mask.reshape([-1]), dim=0, dtype=self.dtype).reshape(self.data.size()) + data = (data - 1) * step + first + data = torch.where(self.mask, data, torch.zeros([], dtype=self.dtype)) + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=data, + mask=self.mask, + ) + + def to(self: Tensor, new_type: typing.Any) -> Tensor: + """ + Convert this tensor to other scalar type. + + Parameters + ---------- + new_type + The scalar data type of the new tensor. + """ + # pylint: disable=invalid-name + if new_type is int: + new_type = torch.int64 + if new_type is float: + new_type = torch.float64 + if new_type is complex: + new_type = torch.complex128 + if isinstance(new_type, str): + if new_type in ["float32", "S"]: + new_type = torch.float32 + elif new_type in ["float64", "float", "D"]: + new_type = torch.float64 + elif new_type in ["complex64", "C"]: + new_type = torch.complex64 + elif new_type in ["complex128", "complex", "Z"]: + new_type = torch.complex128 + if self.dtype.is_complex and new_type.is_floating_point: + data = self.data.real.to(new_type) + else: + data = self.data.to(new_type) + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=data, + mask=self.mask, + ) + + def __init__( + self: Tensor, + names: list[str], + edges: list[Edge], + *, + dtype: typing.Optional[torch.dtype] = None, + fermion: typing.Optional[list[bool]] = None, + dtypes: typing.Optional[list[torch.dtype]] = None, + # The following argument is not public + mask: typing.Optional[torch.Tensor] = None, + data: typing.Optional[torch.Tensor] = None, + ) -> None: + """ + Create a tensor with specific shape. + + Parameters + ---------- + names : list[str] + The edge names of the tensor, which length is just the tensor rank. + edges : list[Edge] + The detail information of each edge, which length is just the tensor rank. + dtype : torch.dtype, optional + The dtype of the tensor, left it empty to let pytorch choose default dtype. + fermion : list[bool], optional + Whether each sub symmetry is fermionic, it could be left empty to derive from edges + dtypes : list[torch.dtype], optional + The base type of sub symmetry, it could be left empty to derive from edges + """ + # Check the rank is correct in names and edges + assert len(names) == len(edges) + # Check whether there are duplicated names + assert len(set(names)) == len(names) + # If fermion not set, get it from edges + if fermion is None: + fermion = edges[0].fermion + # If dtypes not set, get it from edges + if dtypes is None: + dtypes = edges[0].dtypes + # Check if fermion is correct + assert all(edge.fermion == fermion for edge in edges) + # Check if dtypes is correct + assert all(edge.dtypes == dtypes for edge in edges) + + self._fermion: list[bool] = fermion + self._dtypes: list[torch.dtype] = dtypes + self._names: list[str] = names + self._edges: list[Edge] = edges + + self._data: torch.Tensor + if data is None: + if dtype is None: + self._data = torch.zeros([edge.dimension for edge in self.edges]) + else: + self._data = torch.zeros([edge.dimension for edge in self.edges], dtype=dtype) + else: + self._data = data + assert self.data.size() == tuple(edge.dimension for edge in self.edges) + assert dtype is None or self.dtype is dtype + + self._mask: torch.Tensor + if mask is None: + self._mask = self._generate_mask() + else: + self._mask = mask + assert self.mask.size() == tuple(edge.dimension for edge in self.edges) + assert self.mask.dtype is torch.bool + + def _generate_mask(self: Tensor) -> torch.Tensor: + return functools.reduce( + # Mask is valid if all sub symmetry give valid mask. + torch.logical_and, + ( + # The mask is valid if total symmetry is False or total symmetry is 0 + _utility.zero_symmetry( + # total sub symmetry is calculated by reduce + functools.reduce( + # The reduce operator depend on the dtype of this sub symmetry + _utility.add_symmetry, + ( + # The sub symmetry of every edge will be reshape to be reduced. + _utility.unsqueeze(edge.symmetry[sub_symmetry_index], current_index, self.rank) + # The sub symmetry of every edge is reduced one by one + for current_index, edge in enumerate(self.edges)), + # Reduce from a rank-0 tensor + torch.zeros([], dtype=sub_symmetry_dtype), + )) + # Calculate mask on every sub symmetry one by one + for sub_symmetry_index, sub_symmetry_dtype in enumerate(self.dtypes)), + # Reduce from all true mask + torch.ones(self.data.size(), dtype=torch.bool), + ) + + @multimethod + def _prepare_position(self: Tensor, position: tuple[int, ...]) -> tuple[int, ...]: + indices: tuple[int, ...] = position + assert len(indices) == self.rank + assert all(0 <= index < edge.dimension for edge, index in zip(self.edges, indices)) + return indices + + @_prepare_position.register + def _(self: Tensor, position: tuple[slice, ...]) -> tuple[int, ...]: + index_by_name: dict[str, int] = {s.start: s.stop for s in position} + indices: tuple[int, ...] = tuple(index_by_name[name] for name in self.names) + assert len(indices) == self.rank + assert all(0 <= index < edge.dimension for edge, index in zip(self.edges, indices)) + return indices + + @_prepare_position.register + def _(self: Tensor, position: dict[str, int]) -> tuple[int, ...]: + indices: tuple[int, ...] = tuple(position[name] for name in self.names) + assert len(indices) == self.rank + assert all(0 <= index < edge.dimension for edge, index in zip(self.edges, indices)) + return indices + + def __getitem__(self: Tensor, position: tuple[int, ...] | tuple[slice, ...] | dict[str, int]) -> typing.Any: + """ + Get the element of the tensor + + Parameters + ---------- + position : tuple[int, ...] | tuple[slice, ...] | dict[str, int] + The position of the element, which could be either tuple of index directly or a map from edge name to the + index in the corresponding edge. + """ + indices: tuple[int, ...] = self._prepare_position(position) + return self.data[indices] + + def __setitem__(self: Tensor, position: tuple[int, ...] | tuple[slice, ...] | dict[str, int], + value: typing.Any) -> None: + """ + Set the element of the tensor + + Parameters + ---------- + position : tuple[int, ...] | tuple[slice, ...] | dict[str, int] + The position of the element, which could be either tuple of index directly or a map from edge name to the + index in the corresponding edge. + """ + indices: tuple[int, ...] = self._prepare_position(position) + if self.mask[indices]: + self.data[indices] = value + else: + raise IndexError("The indices specified are masked, so it is invalid to set item here.") + + def clear_symmetry(self: Tensor) -> Tensor: + """ + Clear all symmetry of this tensor. + + Returns + ------- + Tensor + The result tensor with symmetry cleared. + """ + # Mask must be generated again here + # pylint: disable=no-else-return + if any(self.fermion): + return Tensor( + names=self.names, + edges=[ + Edge( + fermion=[True], + dtypes=[torch.bool], + symmetry=[edge.parity], + dimension=edge.dimension, + arrow=edge.arrow, + parity=edge.parity, + ) for edge in self.edges + ], + fermion=[True], + dtypes=[torch.bool], + data=self.data, + ) + else: + return Tensor( + names=self.names, + edges=[ + Edge( + fermion=[], + dtypes=[], + symmetry=[], + dimension=edge.dimension, + arrow=edge.arrow, + parity=edge.parity, + ) for edge in self.edges + ], + fermion=[], + dtypes=[], + data=self.data, + ) + + def randn_(self: Tensor, mean: float = 0., std: float = 1.) -> Tensor: + """ + Fill the tensor with random number in normal distribution. + + Parameters + ---------- + mean, std : float + The parameter of normal distribution. + + Returns + ------- + Tensor + Return this tensor itself. + """ + self.data.normal_(mean, std) + torch.where(self.mask, self.data, torch.zeros([], dtype=self.dtype), out=self.data) + return self + + def rand_(self: Tensor, low: float = 0., high: float = 1.) -> Tensor: + """ + Fill the tensor with random number in uniform distribution. + + Parameters + ---------- + low, high : float + The parameter of uniform distribution. + + Returns + ------- + Tensor + Return this tensor itself. + """ + self.data.uniform_(low, high) + torch.where(self.mask, self.data, torch.zeros([], dtype=self.dtype), out=self.data) + return self + + def same_type_with(self: Tensor, other: Tensor) -> bool: + """ + Check whether two tensor has the same type, that is to say they share the same symmetry type. + """ + return self.fermion == other.fermion and self.dtypes == other.dtypes + + def same_shape_with(self: Tensor, other: Tensor, *, allow_transpose: bool = True) -> bool: + """ + Check whether two tensor has the same shape, that is to say the only differences between them are transpose and + data difference. + """ + if not self.same_type_with(other): + return False + # pylint: disable=no-else-return + if allow_transpose: + return dict(zip(self.names, self.edges)) == dict(zip(other.names, other.edges)) + else: + return self.names == other.names and self.edges == other.edges + + def edge_rename(self: Tensor, name_map: dict[str, str]) -> Tensor: + """ + Rename edge name for this tensor. + + Parameters + ---------- + name_map : dict[str, str] + The name map to be used in renaming edge name. + + Returns + ------- + Tensor + The tensor with names renamed. + """ + return Tensor( + names=[name_map.get(name, name) for name in self.names], + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=self.data, + mask=self.mask, + ) + + def transpose(self: Tensor, names: list[str]) -> Tensor: + """ + Transpose the tensor out-place. + + Parameters + ---------- + names : list[str] + The new edge order identified by edge names. + + Returns + ------- + Tensor + The transpose tensor. + """ + if names == self.names: + return self + assert len(names) == len(self.names) + assert set(names) == set(self.names) + before_by_after = [self.names.index(name) for name in names] + after_by_before = [names.index(name) for name in self.names] + data = self.data.permute(before_by_after) + mask = self.mask.permute(before_by_after) + if any(self.fermion): + # It is fermionic tensor + parities_before_transpose = [ + _utility.unsqueeze(edge.parity, current_index, self.rank) + for current_index, edge in enumerate(self.edges) + ] + # Generate parity by xor all inverse pairs + parity = functools.reduce( + torch.logical_xor, + ( + torch.logical_and(parities_before_transpose[i], parities_before_transpose[j]) + # Loop every 0 <= i < j < rank + for j in range(self.rank) + for i in range(0, j) + if after_by_before[i] > after_by_before[j]), + torch.zeros([], dtype=torch.bool)) + # parity True -> -x + # parity False -> +x + data = torch.where(parity.permute(before_by_after), -data, +data) + return Tensor( + names=names, + edges=[self.edges[index] for index in before_by_after], + fermion=self.fermion, + dtypes=self.dtypes, + data=data, + mask=mask, + ) + + def reverse_edge( + self: Tensor, + reversed_names: set[str], + apply_parity: bool = False, + parity_exclude_names: typing.Optional[set[str]] = None, + ) -> Tensor: + """ + Reverse some edge in the tensor. + + Parameters + ---------- + reversed_names : set[str] + The edge names of those edges which will be reversed + apply_parity : bool, default=False + Whether to apply parity caused by reversing edge, since reversing edge will generate half a sign. + parity_exclude_names : set[str], optional + The name of edges in the different behavior other than default set by apply_parity. + + Returns + ------- + Tensor + The tensor with edges reversed. + """ + if not any(self.fermion): + return self + if parity_exclude_names is None: + parity_exclude_names = set() + assert all(name in self.names for name in reversed_names) + assert all(name in reversed_names for name in parity_exclude_names) + data = self.data + if any(self.fermion): + # Parity is xor of all valid reverse parity + parity = functools.reduce( + torch.logical_xor, + ( + _utility.unsqueeze(edge.parity, current_index, self.rank) + # Loop over all edge + for current_index, [name, edge] in enumerate(zip(self.names, self.edges)) + # Check if this edge is reversed and if this edge will be applied parity + if (name in reversed_names) and (apply_parity ^ (name in parity_exclude_names))), + torch.zeros([], dtype=torch.bool), + ) + data = torch.where(parity, -data, +data) + return Tensor( + names=self.names, + edges=[ + Edge( + fermion=edge.fermion, + dtypes=edge.dtypes, + symmetry=edge.symmetry, + dimension=edge.dimension, + arrow=not edge.arrow if self.names[current_index] in reversed_names else edge.arrow, + parity=edge.parity, + ) for current_index, edge in enumerate(self.edges) + ], + fermion=self.fermion, + dtypes=self.dtypes, + data=data, + mask=self.mask, + ) + + @staticmethod + def _split_edge_get_name_group( + name: str, + split_map: dict[str, list[tuple[str, Edge]]], + ) -> list[str]: + split_group: typing.Optional[list[tuple[str, Edge]]] = split_map.get(name, None) + # pylint: disable=no-else-return + if split_group is None: + return [name] + else: + return [new_name for new_name, _ in split_group] + + @staticmethod + def _split_edge_get_edge_group( + name: str, + edge: Edge, + split_map: dict[str, list[tuple[str, Edge]]], + ) -> list[Edge]: + split_group: typing.Optional[list[tuple[str, Edge]]] = split_map.get(name, None) + # pylint: disable=no-else-return + if split_group is None: + return [edge] + else: + return [new_edge for _, new_edge in split_group] + + def split_edge( + self: Tensor, + split_map: dict[str, list[tuple[str, Edge]]], + apply_parity: bool = False, + parity_exclude_names: typing.Optional[set[str]] = None, + ) -> Tensor: + """ + Split some edges in this tensor. + + Parameters + ---------- + split_map : dict[str, list[tuple[str, Edge]]] + The edge splitting plan. + apply_parity : bool, default=False + Whether to apply parity caused by splitting edge, since splitting edge will generate half a sign. + parity_exclude_names : set[str], optional + The name of edges in the different behavior other than default set by apply_parity. + + Returns + ------- + Tensor + The tensor with edges splitted. + """ + if parity_exclude_names is None: + parity_exclude_names = set() + # Check the edge to be splitted can be merged by result edges. + assert all( + self.edge_by_name(name) == Edge.merge_edges( + [new_edge for _, new_edge in split_result], + fermion=self.fermion, + dtypes=self.dtypes, + arrow=self.edge_by_name(name).arrow, + ) for name, split_result in split_map.items()) + assert all(name in split_map for name in parity_exclude_names) + # Calculate the result components + names: list[str] = functools.reduce( + # Concat list + operator.add, + # If name in split_map, use the new names list, otherwise use name itself as a length-1 list + (Tensor._split_edge_get_name_group(name, split_map) for name in self.names), + # Reduce from [] to concat all list + [], + ) + edges: list[Edge] = functools.reduce( + # Concat list + operator.add, + # If name in split_map, use the new edges list, otherwise use the edge itself as a length-1 list + (Tensor._split_edge_get_edge_group(name, edge, split_map) for name, edge in zip(self.names, self.edges)), + # Reduce from [] to concat all list + [], + ) + new_size = [edge.dimension for edge in edges] + data = self.data.reshape(new_size) + mask = self.mask.reshape(new_size) + + # Apply parity + if any(self.fermion): + # It is fermionic tensor, parity need to be applied + new_rank = len(names) + # Parity is xor of all valid splitting parity + parity = functools.reduce( + torch.logical_xor, + ( + # Apply the parity for this splitting group here + # It is need to perform a total transpose on this split group + # {sum 0<=i list[str]: + reversed_names: list[str] = [] + for name in reversed(self.names): + found_in_merge_map: typing.Optional[tuple[str, list[str]]] = next( + ((new_name, old_names) for new_name, old_names in merge_map.items() if name in old_names), None) + if found_in_merge_map is None: + # This edge will not be merged + reversed_names.append(name) + else: + new_name, old_names = found_in_merge_map + # This edge will be merged + if name == old_names[-1]: + # Add new edge only if it is the last edge + reversed_names.append(new_name) + # Some edge is merged from no edges, it should be considered + for new_name, old_names in merge_map.items(): + if not old_names: + reversed_names.append(new_name) + return list(reversed(reversed_names)) + + @staticmethod + def _merge_edge_get_name_group(name: str, merge_map: dict[str, list[str]]) -> list[str]: + merge_group: typing.Optional[list[str]] = merge_map.get(name, None) + # pylint: disable=no-else-return + if merge_group is None: + return [name] + else: + return merge_group + + def merge_edge( + self: Tensor, + merge_map: dict[str, list[str]], + apply_parity: bool = False, + parity_exclude_names: typing.Optional[set[str]] = None, + *, + merge_arrow: typing.Optional[dict[str, bool]] = None, + names: typing.Optional[list[str]] = None, + ) -> Tensor: + """ + Merge some edges in this tensor. + + Parameters + ---------- + merge_map : dict[str, list[str]] + The edge merging plan. + apply_parity : bool, default=False + Whether to apply parity caused by merging edge, since merging edge will generate half a sign. + parity_exclude_names : set[str], optional + The name of edges in the different behavior other than default set by apply_parity. + merge_arrow : dict[str, bool], optional + For merging edge from zero edges, arrow cannot be identified automatically, it requires user set manually. + names : list[str], optional + The edge order of the result, sometimes user may want to specify it manually. + + Returns + ------- + Tensor + The tensor with edges merged. + """ + # pylint: disable=too-many-locals + if parity_exclude_names is None: + parity_exclude_names = set() + if merge_arrow is None: + merge_arrow = {} + assert all(all(old_name in self.names for old_name in old_names) for _, old_names in merge_map.items()) + assert all(name in merge_map for name in parity_exclude_names) + # Two steps: 1. Transpose 2. Merge + if names is None: + names = self._merge_edge_get_names(merge_map) + transposed_names: list[str] = functools.reduce( + # Concat list + operator.add, + # If name in merge_map, use the old names list, otherwise use name itself as a length-1 list + (Tensor._merge_edge_get_name_group(name, merge_map) for name in names), + # Reduce from [] to concat all list + [], + ) + transposed_tensor = self.transpose(transposed_names) + # Prepare a name to index map, since we need to look up it frequently later. + transposed_name_map = {name: index for index, name in enumerate(transposed_tensor.names)} + edges = [ + # If name is created by merging, call Edge.merge_edges to get the merged edge, otherwise get it directly + # from transposed_tensor. + Edge.merge_edges( + edges=[transposed_tensor.edges[transposed_name_map[old_name]] + for old_name in merge_map[name]], + fermion=self.fermion, + dtypes=self.dtypes, + arrow=merge_arrow.get(name, None), + # If merging edge from zero edge, arrow need to be set manually + ) if name in merge_map else transposed_tensor.edges[transposed_name_map[name]] + # Loop over names + for name in names + ] + transposed_data = transposed_tensor.data + transposed_mask = transposed_tensor.mask + + # Apply parity + if any(self.fermion): + # It is fermionic tensor, parity need to be applied + # Parity is xor of all valid merging parity + parity = functools.reduce( + torch.logical_xor, + ( + # Apply the parity for this merging group here + # It is need to perform a total transpose on this merging group + # {sum 0<=i Tensor: + """ + Contract two tensors. + + Parameters + ---------- + other : Tensor + Another tensor to be contracted. + contract_pairs : set[tuple[str, str]] + The pairs of edges to be contract between two tensors. + fuse_names : set[str], optional + The set of edges to be fuses. + + Returns + ------- + Tensor + The result contracted by two tensors. + """ + # pylint: disable=too-many-locals + # Only same type tensor can be contracted. + assert self.same_type_with(other) + + if fuse_names is None: + fuse_names = set() + # Fuse name should not have any symmetry + assert all( + all(_utility.zero_symmetry(sub_symmetry) + for sub_symmetry in self.edge_by_name(fuse_name).symmetry) + for fuse_name in fuse_names) + + # Alias tensor + tensor_1: Tensor = self + tensor_2: Tensor = other + + # Check if contract edge and fuse edge compatible + assert all(tensor_1.edge_by_name(name) == tensor_2.edge_by_name(name) for name in fuse_names) + assert all( + tensor_1.edge_by_name(name_1).conjugate() == tensor_2.edge_by_name(name_2) + for name_1, name_2 in contract_pairs) + + # All tensor edges merged to three part: fuse edge, contract edge, free edge + + # Contract of tensor has 5 step: + # 1. reverse arrow + # reverse all free edge and fuse edge to arrow False, without parity apply. + # reverse contract edge to two arrow: False(tensor_2) and True(tensor_1), only apply parity to one tensor. + # 2. merge edge + # merge all edge in the same part to one edge, only apply parity to contract edge of one tensor + # free edge and fuse edge will not be applied parity. + # 3. contract matrix + # call matrix multiply + # 4. split edge + # split edge merged in step 2, without apply parity + # 5. reverse arrow + # reverse arrow reversed in step 1, without parity apply + + # Step 1 + contract_names_1: set[str] = {name_1 for name_1, name_2 in contract_pairs} + contract_names_2: set[str] = {name_2 for name_1, name_2 in contract_pairs} + arrow_true_names_1: set[str] = {name for name, edge in zip(tensor_1.names, tensor_1.edges) if edge.arrow} + arrow_true_names_2: set[str] = {name for name, edge in zip(tensor_2.names, tensor_2.edges) if edge.arrow} + + # tensor 1: contract_names & arrow_false | not contract_names & arrow_true -> contract_names ^ arrow_true + tensor_1 = tensor_1.reverse_edge(contract_names_1 ^ arrow_true_names_1, False, + contract_names_1 - arrow_true_names_1) + tensor_2 = tensor_2.reverse_edge(arrow_true_names_2, False, set()) + + # Step 2 + free_edges_1: list[tuple[str, Edge]] = [(name, edge) + for name, edge in zip(tensor_1.names, tensor_1.edges) + if name not in fuse_names and name not in contract_names_1] + free_names_1: list[str] = [name for name, _ in free_edges_1] + free_edges_2: list[tuple[str, Edge]] = [(name, edge) + for name, edge in zip(tensor_2.names, tensor_2.edges) + if name not in fuse_names and name not in contract_names_2] + free_names_2: list[str] = [name for name, _ in free_edges_2] + # Check which tensor is bigger, and use it to determine the fuse and contract edge order. + ordered_fuse_edges: list[tuple[str, Edge]] + ordered_fuse_names: list[str] + ordered_contract_names_1: list[str] + ordered_contract_names_2: list[str] + if tensor_1.data.nelement() > tensor_2.data.nelement(): + # Tensor 1 larger, fit by tensor 1 + ordered_fuse_edges = [ + (name, edge) for name, edge in zip(tensor_1.names, tensor_1.edges) if name in fuse_names + ] + ordered_fuse_names = [name for name, _ in ordered_fuse_edges] + + # pylint: disable=unnecessary-comprehension + contract_names_map = {name_1: name_2 for name_1, name_2 in contract_pairs} + ordered_contract_names_1 = [name for name in tensor_1.names if name in contract_names_1] + ordered_contract_names_2 = [contract_names_map[name] for name in ordered_contract_names_1] + else: + # Tensor 2 larger, fit by tensor 2 + ordered_fuse_edges = [ + (name, edge) for name, edge in zip(tensor_2.names, tensor_2.edges) if name in fuse_names + ] + ordered_fuse_names = [name for name, _ in ordered_fuse_edges] + + contract_names_map = {name_2: name_1 for name_1, name_2 in contract_pairs} + ordered_contract_names_2 = [name for name in tensor_2.names if name in contract_names_2] + ordered_contract_names_1 = [contract_names_map[name] for name in ordered_contract_names_2] + + put_contract_1_right: bool = next( + (name in contract_names_1 for name in reversed(tensor_1.names) if name not in fuse_names), True) + put_contract_2_right: bool = next( + (name in contract_names_2 for name in reversed(tensor_2.names) if name not in fuse_names), False) + + tensor_1 = tensor_1.merge_edge( + { + "Free_1": free_names_1, + "Contract_1": ordered_contract_names_1, + "Fuse_1": ordered_fuse_names, + }, + False, + {"Contract_1"}, + merge_arrow={ + "Free_1": False, + "Contract_1": True, + "Fuse_1": False, + }, + names=["Fuse_1", "Free_1", "Contract_1"] if put_contract_1_right else ["Fuse_1", "Contract_1", "Free_1"], + ) + tensor_2 = tensor_2.merge_edge( + { + "Free_2": free_names_2, + "Contract_2": ordered_contract_names_2, + "Fuse_2": ordered_fuse_names, + }, + False, + set(), + merge_arrow={ + "Free_2": False, + "Contract_2": False, + "Fuse_2": False, + }, + names=["Fuse_2", "Free_2", "Contract_2"] if put_contract_2_right else ["Fuse_2", "Contract_2", "Free_2"], + ) + # C[fuse, free1, free2] = A[fuse, free1 contract] B[fuse, contract free2] + assert tensor_1.edge_by_name("Fuse_1") == tensor_2.edge_by_name("Fuse_2") + assert tensor_1.edge_by_name("Contract_1").conjugate() == tensor_2.edge_by_name("Contract_2") + + # Step 3 + # The standard arrow is + # (0, False, True) (0, False, False) + # aka: (a b) (c d) (c+ b+) = (a d) + # since: EPR pair order is (False True) + # put_contract_1_right should be True + # put_contract_2_right should be False + # Every mismatch generate a sign + # Total sign is + # (!put_contract_1_right) ^ (put_contract_2_right) = put_contract_1_right == put_contract_2_right + dtype = torch.result_type(tensor_1.data, tensor_2.data) + data = torch.einsum( + "b" + ("ic" if put_contract_1_right else "ci") + ",b" + ("jc" if put_contract_2_right else "cj") + "->bij", + tensor_1.data.to(dtype=dtype), tensor_2.data.to(dtype=dtype)) + if put_contract_1_right == put_contract_2_right: + data = torch.where(tensor_2.edge_by_name("Free_2").parity.reshape([1, 1, -1]), -data, +data) + tensor = Tensor( + names=["Fuse", "Free_1", "Free_2"], + edges=[tensor_1.edge_by_name("Fuse_1"), + tensor_1.edge_by_name("Free_1"), + tensor_2.edge_by_name("Free_2")], + fermion=self.fermion, + dtypes=self.dtypes, + data=data, + ) + + # Step 4 + tensor = tensor.split_edge({ + "Fuse": ordered_fuse_edges, + "Free_1": free_edges_1, + "Free_2": free_edges_2 + }, False, set()) + + # Step 5 + tensor = tensor.reverse_edge( + (arrow_true_names_1 - contract_names_1) | (arrow_true_names_2 - contract_names_2), + False, + set(), + ) + + return tensor + + def _trace_group_edge( + self: Tensor, + trace_pairs: set[tuple[str, str]], + fuse_names: dict[str, tuple[str, str]], + ) -> tuple[ + list[str], + list[str], + list[str], + list[str], + list[str], + list[str], + list[int], + list[int], + ]: + # pylint: disable=too-many-locals + # pylint: disable=unnecessary-comprehension + trace_map = { + old_name_1: old_name_2 for old_name_1, old_name_2 in trace_pairs + } | { + old_name_2: old_name_1 for old_name_1, old_name_2 in trace_pairs + } + fuse_map = { + old_name_1: (old_name_2, new_name) for new_name, [old_name_1, old_name_2] in fuse_names.items() + } | { + old_name_2: (old_name_1, new_name) for new_name, [old_name_1, old_name_2] in fuse_names.items() + } + reversed_trace_names_1: list[str] = [] + reversed_trace_names_2: list[str] = [] + reversed_fuse_names_1: list[str] = [] + reversed_fuse_names_2: list[str] = [] + reversed_free_names: list[str] = [] + reversed_fuse_names_result: list[str] = [] + reversed_free_index: list[int] = [] + reversed_fuse_index_result: list[int] = [] + added_names: set[str] = set() + for index, name in zip(reversed(range(self.rank)), reversed(self.names)): + if name not in added_names: + trace_name: typing.Optional[str] = trace_map.get(name, None) + fuse_name: typing.Optional[tuple[str, str]] = fuse_map.get(name, None) + if trace_name is not None: + reversed_trace_names_2.append(name) + reversed_trace_names_1.append(trace_name) + added_names.add(trace_name) + elif fuse_name is not None: + # fuse_name = another old name, new name + reversed_fuse_names_2.append(name) + reversed_fuse_names_1.append(fuse_name[0]) + added_names.add(fuse_name[0]) + reversed_fuse_names_result.append(fuse_name[1]) + reversed_fuse_index_result.append(index) + else: + reversed_free_names.append(name) + reversed_free_index.append(index) + return ( + list(reversed(reversed_trace_names_1)), + list(reversed(reversed_trace_names_2)), + list(reversed(reversed_fuse_names_1)), + list(reversed(reversed_fuse_names_2)), + list(reversed(reversed_free_names)), + list(reversed(reversed_fuse_names_result)), + list(reversed(reversed_free_index)), + list(reversed(reversed_fuse_index_result)), + ) + + def trace( + self: Tensor, + trace_pairs: set[tuple[str, str]], + fuse_names: typing.Optional[dict[str, tuple[str, str]]] = None, + ) -> Tensor: + """ + Trace a tensor. + + Parameters + ---------- + trace_pairs : set[tuple[str, str]] + The pairs of edges to be traced + fuse_names : dict[str, tuple[str, str]] + The edges to be fused. + + Returns + ------- + Tensor + The traced tensor. + """ + # pylint: disable=too-many-locals + if fuse_names is None: + fuse_names = {} + # Fuse names should not have any symmetry + assert all( + all(_utility.zero_symmetry(sub_symmetry) + for sub_symmetry in self.edge_by_name(old_name_1).symmetry) + for new_name, [old_name_1, old_name_2] in fuse_names.items()) + # Fuse names should share the same edges + assert all( + self.edge_by_name(old_name_1) == self.edge_by_name(old_name_2) + for new_name, [old_name_1, old_name_2] in fuse_names.items()) + # Trace edges should be compatible + assert all( + self.edge_by_name(old_name_1).conjugate() == self.edge_by_name(old_name_2) + for old_name_1, old_name_2 in trace_pairs) + + # Split trace pairs and fuse names to two part before main part of trace. + [ + trace_names_1, + trace_names_2, + fuse_names_1, + fuse_names_2, + free_names, + fuse_names_result, + free_index, + fuse_index_result, + ] = self._trace_group_edge(trace_pairs, fuse_names) + + # Make alias + tensor = self + + # Tensor edges merged to 5 parts: fuse edge 1, fuse edge 2, trace edge 1, trace edge 2, free edge + # Trace contains 5 step: + # 1. reverse all arrow to False except trace edge 1 to be True, only apply parity to one of two trace edge + # 2. merge all edge to 5 part, only apply parity to one of two trace edge + # 3. trace trivial tensor + # 4. split edge merged in step 2, without apply parity + # 5. reverse arrow reversed in step 1, without apply parity + + # Step 1 + arrow_true_names = {name for name, edge in zip(tensor.names, tensor.edges) if edge.arrow} + unordered_trace_names_1 = set(trace_names_1) + tensor = tensor.reverse_edge(unordered_trace_names_1 ^ arrow_true_names, False, + unordered_trace_names_1 - arrow_true_names) + + # Step 2 + free_edges: list[tuple[str, + Edge]] = [(name, tensor.edges[index]) for name, index in zip(free_names, free_index)] + fuse_edges_result: list[tuple[str, Edge]] = [ + (name, tensor.edges[index]) for name, index in zip(fuse_names_result, fuse_index_result) + ] + tensor = tensor.merge_edge( + { + "Trace_1": trace_names_1, + "Trace_2": trace_names_2, + "Fuse_1": fuse_names_1, + "Fuse_2": fuse_names_2, + "Free": free_names, + }, + False, + {"Trace_1"}, + merge_arrow={ + "Trace_1": True, + "Trace_2": False, + "Fuse_1": False, + "Fuse_2": False, + "Free": False, + }, + names=["Trace_1", "Trace_2", "Fuse_1", "Fuse_2", "Free"], + ) + # B[fuse, free] = A[trace, trace, fuse, fuse, free] + assert tensor.edges[2] == tensor.edges[3] + assert tensor.edges[0].conjugate() == tensor.edges[1] + + # Step 3 + # As tested, the order of edges in this einsum is not important + # ttffi->fi, fftti->fi, ffitt->fi, ttiff->if, ittff->if, ifftt->if + data = torch.einsum("ttffi->fi", tensor.data) + tensor = Tensor( + names=["Fuse", "Free"], + edges=[tensor.edges[2], tensor.edges[4]], + fermion=self.fermion, + dtypes=self.dtypes, + data=data, + ) + + # Step 4 + tensor = tensor.split_edge({ + "Fuse": fuse_edges_result, + "Free": free_edges, + }, False, set()) + + # Step 5 + tensor = tensor.reverse_edge( + # Free edge with arrow true + {name for name in free_names if name in arrow_true_names} | + # New edge from fused edge with arrow true + {new_name for old_name, new_name in zip(fuse_names_1, fuse_names_result) if old_name in arrow_true_names}, + False, + set(), + ) + + return tensor + + def conjugate(self: Tensor, trivial_metric: bool = False) -> Tensor: + """ + Get the conjugate of this tensor. + + Parameters + ---------- + trivial_metric : bool, default=False + Fermionic tensor in network may result in non positive definite metric, set this trivial_metric to True to + ensure the metric to be positive, but it breaks the associative law with tensor contract. + + Returns + ------- + Tensor + The conjugated tensor. + """ + data = torch.conj(self.data) + + # Usually, only a full transpose sign is applied. + # If trivial_metric is set True, parity in edges with arrow True is also applied. + + # Apply parity + if any(self.fermion): + # It is fermionic tensor, parity need to be applied + + # Parity is parity generated from a full transpose + # {sum 0<=i Edge: + # Used in matrix decomposition: SVD and QR + # It relies on decomposition of block tensor is also block tensor. + # Otherwise it cannot guess the correct edge + # QR + # Full rank case: + # QR has uniqueness with a diagonal unitary matrix freedom for full rank case, + # While diagonal unitary does not change the block condition. Since we know there is at least a decomposition + # result which is block matrix, we know all possible decomposition is blocked. + # Proof: + # shape of A is m * n + # if m >= n: + # A = [Q1 U1] [[R1] [0]] = [Q2 U2] [[R2] [0]] + # A is full rank => R1 and R2 are invertible + # Q1 R1 = Q2 R2 and (R1 R2 invertible) => Q2^dagger Q1 = R2 R1^-1, Q1^dagger Q2 = R1 R2^-1 + # lemma: product of inverse of upper triangular matrix is also upper triangular. + # Q2^dagger Q1, Q1^dagger Q2 are upper triangular => Q2^dagger Q1 is upper triangular and lower triangular. + # => Q2^dagger Q1 is diagonal => Q2^dagger Q1 = R2 R1^-1 = S, where S is diagonal matrix. + # => Q1 = Q1 R1 R1^-1 = Q2 R2 R1^-1 = Q2 S => Q1 = Q2 S => S is diagonal unitary. + # At last, we have Q1 = Q2 S where S is a diagonal unitary matrix while S R1 = R2 + # if m < n: + # A = Q1 [R1 N1] = Q2 [R2 N2], so we have Q1 R1 = Q2 R2 + # This is the case for m = n, so Q1 = Q2 S, S R1 = R2. + # At last, Q1 N1 = Q2 S N1 = Q2 N2 implies S N1 = N2. + # Where S is diagonal unitary. + # Rank sufficient case: + # It is hard to get the conclusion. Program may break at this situation. + # SVD + # For non-singular case + # SVD has uniqueness with a blocked unitary matrix freedom, which preserves the singular value subspace. + # So edge guessing fails iff there is the same singular value crossing different quantum number. + # In this case, program may break. + # Proof: + # Let m <= n, since it is symmetric on the dimension. + # A = U1 S1 V1 => U2 S2 V2 => A A^dagger = U1 S1^2 U1^dagger = U2 S2^2 dagger U2 + # The eigenvalue is unique in descending order, while singular value is non-negative real number. + # => S1 = S2 = S, and for eigenvector, U1 = U2 Q where Q is a unitary matrix that [Q S] = 0 + # => U1 S V1 = U2 S V2 = U2 Q S V1 = U2 S Q V2 => S Q V2 = S V1, while S is non-singular, so Q V2 = V1. + # At last, U1 = U2 Q, S1 = S2, Q V1 = V2. + # For singular case + # It is not determined for singular part of unitary. It is similar to the non-similar case. + # But at last step, S Q V2 = S V1 => Q' V2 = V1, where Q' is the same to Q only in non-singular part. + # So, it does break blocks only if blocks has been broken by the same singular value. + # pylint: disable=invalid-name + m, n = matrix.size() + assert edge.dimension == m + argmax = torch.argmax(matrix, dim=0) + assert argmax.size() == (n,) + return Edge( + fermion=edge.fermion, + dtypes=edge.dtypes, + symmetry=[_utility.neg_symmetry(sub_symmetry[argmax]) for sub_symmetry in edge.symmetry], + dimension=n, + arrow=arrow, + parity=edge.parity[argmax], + ) + + def _ensure_mask(self: Tensor) -> None: + """ + Currently this function is only called from SVD decomposition. It ensure that element at mask False is very + small and set them exactly zero. + + Any function other than SVD and QR would not break blocked tensor, while QR is implemented by givens rotation + which preserve the blocks, so there is not need to ensure mask there. + """ + assert torch.allclose(torch.where(self.mask, torch.zeros([], dtype=self.dtype), self.data), + torch.zeros([], dtype=self.dtype)) + self._data = torch.where(self.mask, self.data, torch.zeros([], dtype=self.dtype)) + + def svd( + self: Tensor, + free_names_u: set[str], + common_name_u: str, + common_name_v: str, + singular_name_u: str, + singular_name_v: str, + cut: int = -1, + ) -> tuple[Tensor, Tensor, Tensor]: + """ + SVD decomposition a tensor. Because of the edge created by SVD is guessed based on the SVD result, the program + may break if there is repeated singular value which may result in non-blocked composition result. + + Parameters + ---------- + free_names_u : set[str] + Free names in U tensor of the result of SVD. + common_name_u, common_name_v, singular_name_u, singular_name_v : str + The name of generated edges. + cut : int, default=-1 + The cut for the singular values. + + Returns + ------- + tuple[Tensor, Tensor, Tensor] + U, S, V tensor, the result of SVD. + """ + # pylint: disable=too-many-arguments + # pylint: disable=too-many-locals + + free_names_v = {name for name in self.names if name not in free_names_u} + + assert all(name in self.names for name in free_names_u) + assert common_name_u not in free_names_u + assert common_name_v not in free_names_v + + arrow_true_names = {name for name, edge in zip(self.names, self.edges) if edge.arrow} + + tensor = self.reverse_edge(arrow_true_names, False, set()) + + ordered_free_edges_u: list[tuple[str, Edge]] = [ + (name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in free_names_u + ] + ordered_free_edges_v: list[tuple[str, Edge]] = [ + (name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in free_names_v + ] + ordered_free_names_u: list[str] = [name for name, _ in ordered_free_edges_u] + ordered_free_names_v: list[str] = [name for name, _ in ordered_free_edges_v] + + put_v_right = next((name in free_names_v for name in reversed(tensor.names)), True) + tensor = tensor.merge_edge( + { + "SVD_U": ordered_free_names_u, + "SVD_V": ordered_free_names_v + }, + False, + set(), + merge_arrow={ + "SVD_U": False, + "SVD_V": False + }, + names=["SVD_U", "SVD_V"] if put_v_right else ["SVD_V", "SVD_U"], + ) + + # if self.fermion: + # data_1, data_s, data_2 = manual_svd(tensor.data, 1e-6) + # else: + # data_1, data_s, data_2 = torch.linalg.svd(tensor.data, full_matrices=False) + data_1, data_s, data_2 = torch.linalg.svd(tensor.data, full_matrices=False) + + if cut != -1: + data_1 = data_1[:, :cut] + data_s = data_s[:cut] + data_2 = data_2[:cut, :] + data_s = torch.diag_embed(data_s) + + free_edge_1 = tensor.edges[0] + common_edge_1 = Tensor._guess_edge(torch.abs(data_1), free_edge_1, True) + tensor_1 = Tensor( + names=["SVD_U", common_name_u] if put_v_right else ["SVD_V", common_name_v], + edges=[free_edge_1, common_edge_1], + fermion=self.fermion, + dtypes=self.dtypes, + data=data_1, + ) + tensor_1._ensure_mask() # pylint: disable=protected-access + free_edge_2 = tensor.edges[1] + common_edge_2 = Tensor._guess_edge(torch.abs(data_2).transpose(0, 1), free_edge_2, False) + tensor_2 = Tensor( + names=[common_name_v, "SVD_V"] if put_v_right else [common_name_u, "SVD_U"], + edges=[common_edge_2, free_edge_2], + fermion=self.fermion, + dtypes=self.dtypes, + data=data_2, + ) + tensor_2._ensure_mask() # pylint: disable=protected-access + assert common_edge_1.conjugate() == common_edge_2 + tensor_s = Tensor( + names=[singular_name_u, singular_name_v] if put_v_right else [singular_name_v, singular_name_u], + edges=[common_edge_2, common_edge_1], + fermion=self.fermion, + dtypes=self.dtypes, + data=data_s, + ) + + tensor_u = tensor_1 if put_v_right else tensor_2 + tensor_v = tensor_2 if put_v_right else tensor_1 + + tensor_u = tensor_u.split_edge({"SVD_U": ordered_free_edges_u}, False, set()) + tensor_v = tensor_v.split_edge({"SVD_V": ordered_free_edges_v}, False, set()) + + tensor_u = tensor_u.reverse_edge(arrow_true_names & free_names_u, False, set()) + tensor_v = tensor_v.reverse_edge(arrow_true_names & free_names_v, False, set()) + + return tensor_u, tensor_s, tensor_v + + def qr( + self: Tensor, + free_names_direction: str, + free_names: set[str], + common_name_q: str, + common_name_r: str, + ) -> tuple[Tensor, Tensor]: + """ + QR decomposition on this tensor. Because of the edge created by QR is guessed based on the QR result, the + program may break if the tensor is rank deficient which may result in non-blocked composition result. + + Parameters + ---------- + free_names_direction : 'Q' | 'q' | 'R' | 'r' + Specify which direction the free_names will set + free_names : set[str] + The names of free edges after QR decomposition. + common_name_q, common_name_r : str + The names of edges created by QR decomposition. + + Returns + ------- + tuple[Tensor, Tensor] + Tensor Q and R, the result of QR decomposition. + """ + # pylint: disable=invalid-name + # pylint: disable=too-many-locals + + if free_names_direction in {'Q', 'q'}: + free_names_q = free_names + free_names_r = {name for name in self.names if name not in free_names} + elif free_names_direction in {'R', 'r'}: + free_names_r = free_names + free_names_q = {name for name in self.names if name not in free_names} + + assert all(name in self.names for name in free_names) + assert common_name_q not in free_names_q + assert common_name_r not in free_names_r + + arrow_true_names = {name for name, edge in zip(self.names, self.edges) if edge.arrow} + + tensor = self.reverse_edge(arrow_true_names, False, set()) + + ordered_free_edges_q: list[tuple[str, Edge]] = [ + (name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in free_names_q + ] + ordered_free_edges_r: list[tuple[str, Edge]] = [ + (name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in free_names_r + ] + ordered_free_names_q: list[str] = [name for name, _ in ordered_free_edges_q] + ordered_free_names_r: list[str] = [name for name, _ in ordered_free_edges_r] + + # pytorch does not provide LQ, so always put r right here + tensor = tensor.merge_edge( + { + "QR_Q": ordered_free_names_q, + "QR_R": ordered_free_names_r + }, + False, + set(), + merge_arrow={ + "QR_Q": False, + "QR_R": False + }, + names=["QR_Q", "QR_R"], + ) + + # if self.fermion: + # # Blocked tensor, use Givens rotation + # data_q, data_r = givens_qr(tensor.data) + # else: + # # Non-blocked tensor, use Householder reflection + # data_q, data_r = torch.linalg.qr(tensor.data, mode="reduced") + data_q, data_r = torch.linalg.qr(tensor.data, mode="reduced") + + free_edge_q = tensor.edges[0] + common_edge_q = Tensor._guess_edge(torch.abs(data_q), free_edge_q, True) + tensor_q = Tensor( + names=["QR_Q", common_name_q], + edges=[free_edge_q, common_edge_q], + fermion=self.fermion, + dtypes=self.dtypes, + data=data_q, + ) + tensor_q._ensure_mask() # pylint: disable=protected-access + free_edge_r = tensor.edges[1] + # common_edge_r = Tensor._guess_edge(torch.abs(data_r).transpose(0, 1), free_edge_r, False) + # Sometimes R matrix maybe singular, guess edge will return arbitary symmetry, so do not use guessed edge. + common_edge_r = common_edge_q.conjugate() + tensor_r = Tensor( + names=[common_name_r, "QR_R"], + edges=[common_edge_r, free_edge_r], + fermion=self.fermion, + dtypes=self.dtypes, + data=data_r, + ) + tensor_r._ensure_mask() # pylint: disable=protected-access + assert common_edge_q.conjugate() == common_edge_r + + tensor_q = tensor_q.split_edge({"QR_Q": ordered_free_edges_q}, False, set()) + tensor_r = tensor_r.split_edge({"QR_R": ordered_free_edges_r}, False, set()) + + tensor_q = tensor_q.reverse_edge(arrow_true_names & free_names_q, False, set()) + tensor_r = tensor_r.reverse_edge(arrow_true_names & free_names_r, False, set()) + + return tensor_q, tensor_r + + def identity(self: Tensor, pairs: set[tuple[str, str]]) -> Tensor: + """ + Get the identity tensor with same shape to this tensor. + + Parameters + ---------- + pairs : set[tuple[str, str]] + The pair of edge names to specify the relation among edges to set identity tensor. + + Returns + ------- + Tensor + The result identity tensor. + """ + # The order of edges before setting identity should be (False True) + # Merge tensor directly to two edge, set identity and split it directly. + # When splitting, only apply parity to one part of edges + + # pylint: disable=unnecessary-comprehension + pairs_map = {name_1: name_2 for name_1, name_2 in pairs} | {name_2: name_1 for name_1, name_2 in pairs} + added_names: set[str] = set() + reversed_names_1: list[str] = [] + reversed_names_2: list[str] = [] + for name in reversed(self.names): + if name not in added_names: + another_name = pairs_map[name] + reversed_names_2.append(name) + reversed_names_1.append(another_name) + added_names.add(another_name) + names_1 = list(reversed(reversed_names_1)) + names_2 = list(reversed(reversed_names_2)) + # unordered_names_1 = set(names_1) + unordered_names_2 = set(names_2) + + arrow_true_names = {name for name, edge in zip(self.names, self.edges) if edge.arrow} + + # Two edges, arrow of two edges are (False, True) + tensor = self.reverse_edge(unordered_names_2 ^ arrow_true_names, False, unordered_names_2 - arrow_true_names) + + edges_1 = [(name, tensor.edge_by_name(name)) for name in names_1] + edges_2 = [(name, tensor.edge_by_name(name)) for name in names_2] + + tensor = tensor.merge_edge( + { + "Identity_1": names_1, + "Identity_2": names_2 + }, + False, + {"Identity_2"}, + merge_arrow={ + "Identity_1": False, + "Identity_2": True + }, + names=["Identity_1", "Identity_2"], + ) + + tensor = Tensor( + names=tensor.names, + edges=tensor.edges, + fermion=tensor.fermion, + dtypes=tensor.dtypes, + data=torch.eye(*tensor.data.size()), + mask=tensor.mask, + ) + + tensor = tensor.split_edge({"Identity_1": edges_1, "Identity_2": edges_2}, False, {"Identity_2"}) + + tensor = tensor.reverse_edge(unordered_names_2 ^ arrow_true_names, False, unordered_names_2 - arrow_true_names) + + return tensor + + def exponential(self: Tensor, pairs: set[tuple[str, str]]) -> Tensor: + """ + Get the exponential tensor of this tensor. + + Parameters + ---------- + pairs : set[tuple[str, str]] + The pair of edge names to specify the relation among edges to calculate exponential tensor. + + Returns + ------- + Tensor + The result exponential tensor. + """ + # The order of edges before setting exponential should be (False True) + # Merge tensor directly to two edge, set exponential and split it directly. + # When splitting, only apply parity to one part of edges + + unordered_names_1 = {name_1 for name_1, name_2 in pairs} + unordered_names_2 = {name_2 for name_1, name_2 in pairs} + if self.names and self.names[-1] in unordered_names_1: + unordered_names_1, unordered_names_2 = unordered_names_2, unordered_names_1 + # pylint: disable=unnecessary-comprehension + pairs_map = {name_1: name_2 for name_1, name_2 in pairs} | {name_2: name_1 for name_1, name_2 in pairs} + reversed_names_1: list[str] = [] + reversed_names_2: list[str] = [] + for name in reversed(self.names): + if name in unordered_names_2: + another_name = pairs_map[name] + reversed_names_2.append(name) + reversed_names_1.append(another_name) + names_1 = list(reversed(reversed_names_1)) + names_2 = list(reversed(reversed_names_2)) + + arrow_true_names = {name for name, edge in zip(self.names, self.edges) if edge.arrow} + + # Two edges, arrow of two edges are (False, True) + tensor = self.reverse_edge(unordered_names_2 ^ arrow_true_names, False, unordered_names_2 - arrow_true_names) + + edges_1 = [(name, tensor.edge_by_name(name)) for name in names_1] + edges_2 = [(name, tensor.edge_by_name(name)) for name in names_2] + + tensor = tensor.merge_edge( + { + "Exponential_1": names_1, + "Exponential_2": names_2 + }, + False, + {"Exponential_2"}, + merge_arrow={ + "Exponential_1": False, + "Exponential_2": True + }, + names=["Exponential_1", "Exponential_2"], + ) + + tensor = Tensor( + names=tensor.names, + edges=tensor.edges, + fermion=tensor.fermion, + dtypes=tensor.dtypes, + data=torch.linalg.matrix_exp(tensor.data), + mask=tensor.mask, + ) + + tensor = tensor.split_edge({"Exponential_1": edges_1, "Exponential_2": edges_2}, False, {"Exponential_2"}) + + tensor = tensor.reverse_edge(unordered_names_2 ^ arrow_true_names, False, unordered_names_2 - arrow_true_names) + + return tensor diff --git a/tests/test_compat.py b/tests/test_compat.py new file mode 100644 index 000000000..824610cf6 --- /dev/null +++ b/tests/test_compat.py @@ -0,0 +1,124 @@ +"Test compat" + +import torch +import tat +from tat import compat as TAT + +# pylint: disable=missing-function-docstring +# pylint: disable=invalid-name +# pylint: disable=singleton-comparison + +# It is strange, but pylint complains function args too many. So add it here +# pylint: disable=too-many-function-args + + +def test_edge_from_dimension() -> None: + assert TAT.No.Edge(4) == tat.Edge(dimension=4) + assert TAT.Fermi.Edge(4) == tat.Edge(fermion=[True], + symmetry=[torch.tensor([0, 0, 0, 0], dtype=torch.int)], + arrow=False) + assert TAT.Z2.Edge(4) == tat.Edge(symmetry=[torch.tensor([False, False, False, False])]) + + +def test_edge_from_segments() -> None: + assert TAT.Z2.Edge([ + (False, 2), + (True, 3), + ]) == tat.Edge(symmetry=[torch.tensor([False, False, True, True, True])]) + assert TAT.Fermi.Edge([ + (-1, 1), + (0, 2), + (+1, 3), + ], True) == tat.Edge( + symmetry=[torch.tensor([-1, 0, 0, +1, +1, +1], dtype=torch.int)], + arrow=True, + fermion=[True], + ) + assert TAT.FermiFermi.Edge([ + ((-1, -2), 1), + ((0, +1), 2), + ((+1, 0), 3), + ], True) == tat.Edge( + symmetry=[ + torch.tensor([-1, 0, 0, +1, +1, +1], dtype=torch.int), + torch.tensor([-2, +1, +1, 0, 0, 0], dtype=torch.int), + ], + arrow=True, + fermion=[True, True], + ) + + +def test_edge_from_segments_without_dimension() -> None: + assert TAT.Z2.Edge([False, True]) == tat.Edge(symmetry=[torch.tensor([False, True])]) + assert TAT.Fermi.Edge([-1, 0, +1], True) == tat.Edge( + symmetry=[torch.tensor([-1, 0, +1], dtype=torch.int)], + arrow=True, + fermion=[True], + ) + assert TAT.FermiFermi.Edge([ + (-1, -2), + (0, +1), + (+1, 0), + ], True) == tat.Edge( + symmetry=[torch.tensor([-1, 0, +1], dtype=torch.int), + torch.tensor([-2, +1, 0], dtype=torch.int)], + arrow=True, + fermion=[True, True], + ) + + +def test_edge_from_tuple() -> None: + assert TAT.FermiFermi.Edge(([ + ((-1, -2), 1), + ((0, +1), 2), + ((+1, 0), 3), + ], True)) == tat.Edge( + symmetry=[ + torch.tensor([-1, 0, 0, +1, +1, +1], dtype=torch.int), + torch.tensor([-2, +1, +1, 0, 0, 0], dtype=torch.int), + ], + arrow=True, + fermion=[True, True], + ) + assert TAT.FermiFermi.Edge(([ + (-1, -2), + (0, +1), + (+1, 0), + ], True)) == tat.Edge( + symmetry=[torch.tensor([-1, 0, +1], dtype=torch.int), + torch.tensor([-2, +1, 0], dtype=torch.int)], + arrow=True, + fermion=[True, True], + ) + + +def test_tensor() -> None: + a = TAT.FermiZ2.D.Tensor(["i", "j"], [ + [(-1, False), (-1, True), (0, True), (0, False)], + [(+1, True), (+1, False), (0, False), (0, True)], + ]) + b = tat.Tensor( + [ + "i", + "j", + ], + [ + tat.Edge( + fermion=[True, False], + symmetry=[ + torch.tensor([-1, -1, 0, 0], dtype=torch.int), + torch.tensor([False, True, True, False]), + ], + arrow=False, + ), + tat.Edge( + fermion=[True, False], + symmetry=[ + torch.tensor([+1, +1, 0, 0], dtype=torch.int), + torch.tensor([True, False, False, True]), + ], + arrow=False, + ), + ], + ) + assert a.same_shape_with(b, allow_transpose=False) diff --git a/tests/test_create_tensor.py b/tests/test_create_tensor.py new file mode 100644 index 000000000..5335bc992 --- /dev/null +++ b/tests/test_create_tensor.py @@ -0,0 +1,103 @@ +"Test create tensor" + +import torch +import tat + +# pylint: disable=missing-function-docstring +# pylint: disable=invalid-name +# pylint: disable=singleton-comparison + + +def test_create_tensor() -> None: + a = tat.Tensor( + [ + "i", + "j", + ], + [ + tat.Edge(symmetry=[torch.tensor([False, False, True])], fermion=[True], arrow=True), + tat.Edge(symmetry=[torch.tensor([False, False, False, True, True])], fermion=[True], arrow=False), + ], + ) + assert a.rank == 2 + assert a.names == ["i", "j"] + assert a.edges[0] == tat.Edge(symmetry=[torch.tensor([False, False, True])], fermion=[True], arrow=True) + assert a.edges[1] == tat.Edge(symmetry=[torch.tensor([False, False, False, True, True])], + fermion=[True], + arrow=False) + assert a.edges[0] == a.edge_by_name("i") + assert a.edges[1] == a.edge_by_name("j") + + +def test_tensor_get_set_item() -> None: + a = tat.Tensor( + [ + "i", + "j", + ], + [ + tat.Edge(symmetry=[torch.tensor([False, False, True])], fermion=[True], arrow=True), + tat.Edge(symmetry=[torch.tensor([False, False, False, True, True])], fermion=[True], arrow=False), + ], + ) + a[{"i": 0, "j": 0}] = 1 + assert a[0, 0] == 1 + a["i":2, "j":3] = 2 # type: ignore[misc] + assert a[{"i": 2, "j": 3}] == 2 + try: + a[2, 0] = 3 + assert False + except IndexError: + pass + assert a["i":2, "j":0] == 0 # type: ignore[misc] + + b = tat.Tensor( + [ + "i", + "j", + ], + [ + tat.Edge(symmetry=[torch.tensor([0, 0, -1])], fermion=[False]), + tat.Edge(symmetry=[torch.tensor([0, 0, 0, +1, +1])], fermion=[False]), + ], + ) + b[{"i": 0, "j": 0}] = 1 + assert b[0, 0] == 1 + b["i":2, "j":3] = 2 # type: ignore[misc] + assert b[{"i": 2, "j": 3}] == 2 + try: + b[2, 0] = 3 + assert False + except IndexError: + pass + assert b["i":2, "j":0] == 0 # type: ignore[misc] + + +def test_create_randn_tensor() -> None: + a = tat.Tensor( + ["i", "j"], + [ + tat.Edge(symmetry=[torch.tensor([False, True])]), + tat.Edge(symmetry=[torch.tensor([False, True])]), + ], + dtype=torch.float16, + ).randn_() + assert a.dtype == torch.float16 + assert a[0, 0] != 0 + assert a[1, 1] != 0 + assert a[0, 1] == 0 + assert a[1, 0] == 0 + + b = tat.Tensor( + ["i", "j"], + [ + tat.Edge(symmetry=[torch.tensor([False, False]), torch.tensor([0, -1])]), + tat.Edge(symmetry=[torch.tensor([False, False]), torch.tensor([0, +1])]), + ], + dtype=torch.float16, + ).randn_() + assert b.dtype == torch.float16 + assert b[0, 0] != 0 + assert b[1, 1] != 0 + assert b[0, 1] == 0 + assert b[1, 0] == 0 diff --git a/tests/test_edge.py b/tests/test_edge.py new file mode 100644 index 000000000..d6c6eaa58 --- /dev/null +++ b/tests/test_edge.py @@ -0,0 +1,39 @@ +"Test edge" + +import torch +from tat import Edge + +# pylint: disable=missing-function-docstring +# pylint: disable=invalid-name +# pylint: disable=singleton-comparison + + +def test_create_edge_and_basic() -> None: + a = Edge(dimension=5) + assert a.arrow == False + assert a.dimension == 5 + b = Edge(symmetry=[torch.tensor([False, False, True, True])]) + assert b.arrow == False + assert b.dimension == 4 + c = Edge(fermion=[False, True], symmetry=[torch.tensor([False, True]), torch.tensor([False, True])], arrow=True) + assert c.arrow == True + assert c.dimension == 2 + + +def test_edge_conjugate_and_equal() -> None: + a = Edge(fermion=[False, True], symmetry=[torch.tensor([False, True]), torch.tensor([0, 1])], arrow=True) + b = Edge(fermion=[False, True], symmetry=[torch.tensor([False, True]), torch.tensor([0, -1])], arrow=False) + assert a.conjugate() == b + assert a != 2 + + +def test_repr() -> None: + a = Edge(fermion=[False, True], symmetry=[torch.tensor([False, True]), torch.tensor([0, 1])], arrow=True) + repr_a = "Edge(dimension=2, arrow=True, fermion=(False,True), symmetry=([False,True],[0,1]))" + assert repr_a == repr(a) + b = Edge(symmetry=[torch.tensor([False, True]), torch.tensor([0, 1])]) + repr_b = "Edge(dimension=2, symmetry=([False,True],[0,1]))" + assert repr_b == repr(b) + c = Edge(dimension=4) + repr_c = "Edge(dimension=4)" + assert repr_c == repr(c) diff --git a/tests/test_qr.py b/tests/test_qr.py new file mode 100644 index 000000000..5085cc261 --- /dev/null +++ b/tests/test_qr.py @@ -0,0 +1,59 @@ +"Test QR" + +import torch +from tat._qr import givens_qr, householder_qr + +# pylint: disable=missing-function-docstring +# pylint: disable=invalid-name + + +def check_givens(A: torch.Tensor) -> None: + m, n = A.size() + Q, R = givens_qr(A) + assert torch.allclose(A, Q @ R) + assert torch.allclose(Q.H @ Q, torch.eye(min(m, n), dtype=A.dtype, device=A.device)) + grad_check = torch.autograd.gradcheck( + givens_qr, + A, + eps=1e-8, + atol=1e-4, + ) + assert grad_check + + +def test_qr_real_givens() -> None: + check_givens(torch.randn(7, 5, dtype=torch.float64, requires_grad=True)) + check_givens(torch.randn(5, 5, dtype=torch.float64, requires_grad=True)) + check_givens(torch.randn(5, 7, dtype=torch.float64, requires_grad=True)) + + +def test_qr_complex_givens() -> None: + check_givens(torch.randn(7, 5, dtype=torch.complex128, requires_grad=True)) + check_givens(torch.randn(5, 5, dtype=torch.complex128, requires_grad=True)) + check_givens(torch.randn(5, 7, dtype=torch.complex128, requires_grad=True)) + + +def check_householder(A: torch.Tensor) -> None: + m, n = A.size() + Q, R = householder_qr(A) + assert torch.allclose(A, Q @ R) + assert torch.allclose(Q.H @ Q, torch.eye(min(m, n), dtype=A.dtype, device=A.device)) + grad_check = torch.autograd.gradcheck( + householder_qr, + A, + eps=1e-8, + atol=1e-4, + ) + assert grad_check + + +def test_qr_real_householder() -> None: + check_householder(torch.randn(7, 5, dtype=torch.float64, requires_grad=True)) + check_householder(torch.randn(5, 5, dtype=torch.float64, requires_grad=True)) + check_householder(torch.randn(5, 7, dtype=torch.float64, requires_grad=True)) + + +def test_qr_complex_householder() -> None: + check_householder(torch.randn(7, 5, dtype=torch.complex128, requires_grad=True)) + check_householder(torch.randn(5, 5, dtype=torch.complex128, requires_grad=True)) + check_householder(torch.randn(5, 7, dtype=torch.complex128, requires_grad=True)) diff --git a/tests/test_svd.py b/tests/test_svd.py new file mode 100644 index 000000000..14610352e --- /dev/null +++ b/tests/test_svd.py @@ -0,0 +1,40 @@ +"Test SVD" + +import torch +from tat._svd import svd + +# pylint: disable=missing-function-docstring +# pylint: disable=invalid-name + + +def svd_func(A: torch.Tensor) -> torch.Tensor: + U, S, V = svd(A, 1e-10) + return U @ torch.diag(S).to(dtype=A.dtype) @ V + + +def check_svd(A: torch.Tensor) -> None: + m, n = A.size() + U, S, V = svd(A, 1e-10) + assert torch.allclose(U @ torch.diag(S.to(dtype=A.dtype)) @ V, A) + assert torch.allclose(U.H @ U, torch.eye(min(m, n), dtype=A.dtype, device=A.device)) + assert torch.allclose(V @ V.H, torch.eye(min(m, n), dtype=A.dtype, device=A.device)) + grad_check = torch.autograd.gradcheck( + svd_func, + A, + eps=1e-8, + atol=1e-4, + nondet_tol=1e-10, + ) + assert grad_check + + +def test_svd_real() -> None: + check_svd(torch.randn(7, 5, dtype=torch.float64, requires_grad=True)) + check_svd(torch.randn(5, 5, dtype=torch.float64, requires_grad=True)) + check_svd(torch.randn(5, 7, dtype=torch.float64, requires_grad=True)) + + +def test_svd_complex() -> None: + check_svd(torch.randn(7, 5, dtype=torch.complex128, requires_grad=True)) + check_svd(torch.randn(5, 5, dtype=torch.complex128, requires_grad=True)) + check_svd(torch.randn(5, 7, dtype=torch.complex128, requires_grad=True))