From cb107699bee805d17f166ee3384f772fc0812570 Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Wed, 1 Nov 2023 10:10:20 +0800 Subject: [PATCH] Implement prototype for torch based fermionic library. --- .github/workflows/CI.yml | 48 + .gitignore | 4 + LICENSE.md | 675 +++++++++++++ README.md | 3 + pyproject.toml | 31 + tat/__init__.py | 6 + tat/_utility.py | 41 + tat/compat.py | 235 +++++ tat/edge.py | 284 ++++++ tat/tensor.py | 1885 +++++++++++++++++++++++++++++++++++ tests/test_compat.py | 113 +++ tests/test_create_tensor.py | 89 ++ tests/test_edge.py | 33 + 13 files changed, 3447 insertions(+) create mode 100644 .github/workflows/CI.yml create mode 100644 .gitignore create mode 100644 LICENSE.md create mode 100644 README.md create mode 100644 pyproject.toml create mode 100644 tat/__init__.py create mode 100644 tat/_utility.py create mode 100644 tat/compat.py create mode 100644 tat/edge.py create mode 100644 tat/tensor.py create mode 100644 tests/test_compat.py create mode 100644 tests/test_create_tensor.py create mode 100644 tests/test_edge.py diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml new file mode 100644 index 000000000..2eb3ce9db --- /dev/null +++ b/.github/workflows/CI.yml @@ -0,0 +1,48 @@ +name: CI + +on: [push, pull_request] + +jobs: + CI: + + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - python-version: "3.10" + pytorch-version: "1.12" + - python-version: "3.10" + pytorch-version: "1.13" + - python-version: "3.10" + pytorch-version: "2.0" + - python-version: "3.10" + pytorch-version: "2.1" + + - python-version: "3.11" + pytorch-version: "1.13" + - python-version: "3.11" + pytorch-version: "2.0" + - python-version: "3.11" + pytorch-version: "2.1" + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install requirements + run: | + pip install pylint==2.17 mypy==1.6 pytest==7.4 pytest-cov==4.1 + pip install torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu + pip install multimethod + - name: Run pylint + run: pylint tat + working-directory: ${{ github.workspace }} + - name: Run mypy + run: mypy tat + working-directory: ${{ github.workspace }} + - name: Run pytest + run: pytest + working-directory: ${{ github.workspace }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..a5d2ad6c8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.coverage +.mypy_cache +__pycache__ +env \ No newline at end of file diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 000000000..496acdb2a --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,675 @@ +# GNU GENERAL PUBLIC LICENSE + +Version 3, 29 June 2007 + +Copyright (C) 2007 Free Software Foundation, Inc. + + +Everyone is permitted to copy and distribute verbatim copies of this +license document, but changing it is not allowed. + +## Preamble + +The GNU General Public License is a free, copyleft license for +software and other kinds of works. + +The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom +to share and change all versions of a program--to make sure it remains +free software for all its users. We, the Free Software Foundation, use +the GNU General Public License for most of our software; it applies +also to any other work released this way by its authors. You can apply +it to your programs, too. + +When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + +To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you +have certain responsibilities if you distribute copies of the +software, or if you modify it: responsibilities to respect the freedom +of others. + +For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + +Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + +For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + +Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the +manufacturer can do so. This is fundamentally incompatible with the +aim of protecting users' freedom to change the software. The +systematic pattern of such abuse occurs in the area of products for +individuals to use, which is precisely where it is most unacceptable. +Therefore, we have designed this version of the GPL to prohibit the +practice for those products. If such problems arise substantially in +other domains, we stand ready to extend this provision to those +domains in future versions of the GPL, as needed to protect the +freedom of users. + +Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish +to avoid the special danger that patents applied to a free program +could make it effectively proprietary. To prevent this, the GPL +assures that patents cannot be used to render the program non-free. + +The precise terms and conditions for copying, distribution and +modification follow. + +## TERMS AND CONDITIONS + +### 0. Definitions. + +"This License" refers to version 3 of the GNU General Public License. + +"Copyright" also means copyright-like laws that apply to other kinds +of works, such as semiconductor masks. + +"The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + +To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of +an exact copy. The resulting work is called a "modified version" of +the earlier work or a work "based on" the earlier work. + +A "covered work" means either the unmodified Program or a work based +on the Program. + +To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + +To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user +through a computer network, with no transfer of a copy, is not +conveying. + +An interactive user interface displays "Appropriate Legal Notices" to +the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + +### 1. Source Code. + +The "source code" for a work means the preferred form of the work for +making modifications to it. "Object code" means any non-source form of +a work. + +A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + +The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + +The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + +The Corresponding Source need not include anything that users can +regenerate automatically from other parts of the Corresponding Source. + +The Corresponding Source for a work in source code form is that same +work. + +### 2. Basic Permissions. + +All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + +You may make, run and propagate covered works that you do not convey, +without conditions so long as your license otherwise remains in force. +You may convey covered works to others for the sole purpose of having +them make modifications exclusively for you, or provide you with +facilities for running those works, provided that you comply with the +terms of this License in conveying all material for which you do not +control copyright. Those thus making or running the covered works for +you must do so exclusively on your behalf, under your direction and +control, on terms that prohibit them from making any copies of your +copyrighted material outside their relationship with you. + +Conveying under any other circumstances is permitted solely under the +conditions stated below. Sublicensing is not allowed; section 10 makes +it unnecessary. + +### 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + +No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + +When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such +circumvention is effected by exercising rights under this License with +respect to the covered work, and you disclaim any intention to limit +operation or modification of the work as a means of enforcing, against +the work's users, your or third parties' legal rights to forbid +circumvention of technological measures. + +### 4. Conveying Verbatim Copies. + +You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + +You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + +### 5. Conveying Modified Source Versions. + +You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these +conditions: + +- a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. +- b) The work must carry prominent notices stating that it is + released under this License and any conditions added under + section 7. This requirement modifies the requirement in section 4 + to "keep intact all notices". +- c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. +- d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + +A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + +### 6. Conveying Non-Source Forms. + +You may convey a covered work in object code form under the terms of +sections 4 and 5, provided that you also convey the machine-readable +Corresponding Source under the terms of this License, in one of these +ways: + +- a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. +- b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the Corresponding + Source from a network server at no charge. +- c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. +- d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. +- e) Convey the object code using peer-to-peer transmission, + provided you inform other peers where the object code and + Corresponding Source of the work are being offered to the general + public at no charge under subsection 6d. + +A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + +A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, +family, or household purposes, or (2) anything designed or sold for +incorporation into a dwelling. In determining whether a product is a +consumer product, doubtful cases shall be resolved in favor of +coverage. For a particular product received by a particular user, +"normally used" refers to a typical or common use of that class of +product, regardless of the status of the particular user or of the way +in which the particular user actually uses, or expects or is expected +to use, the product. A product is a consumer product regardless of +whether the product has substantial commercial, industrial or +non-consumer uses, unless such uses represent the only significant +mode of use of the product. + +"Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to +install and execute modified versions of a covered work in that User +Product from a modified version of its Corresponding Source. The +information must suffice to ensure that the continued functioning of +the modified object code is in no case prevented or interfered with +solely because modification has been made. + +If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + +The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or +updates for a work that has been modified or installed by the +recipient, or for the User Product in which it has been modified or +installed. Access to a network may be denied when the modification +itself materially and adversely affects the operation of the network +or violates the rules and protocols for communication across the +network. + +Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + +### 7. Additional Terms. + +"Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + +When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + +Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders +of that material) supplement the terms of this License with terms: + +- a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or +- b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or +- c) Prohibiting misrepresentation of the origin of that material, + or requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or +- d) Limiting the use for publicity purposes of names of licensors + or authors of the material; or +- e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or +- f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions + of it) with contractual assumptions of liability to the recipient, + for any liability that these contractual assumptions directly + impose on those licensors and authors. + +All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + +If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + +Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; the +above requirements apply either way. + +### 8. Termination. + +You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + +However, if you cease all violation of this License, then your license +from a particular copyright holder is reinstated (a) provisionally, +unless and until the copyright holder explicitly and finally +terminates your license, and (b) permanently, if the copyright holder +fails to notify you of the violation by some reasonable means prior to +60 days after the cessation. + +Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + +Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + +### 9. Acceptance Not Required for Having Copies. + +You are not required to accept this License in order to receive or run +a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + +### 10. Automatic Licensing of Downstream Recipients. + +Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + +An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + +You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + +### 11. Patents. + +A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + +A contributor's "essential patent claims" are all patent claims owned +or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + +Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + +In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + +If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + +If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + +A patent license is "discriminatory" if it does not include within the +scope of its coverage, prohibits the exercise of, or is conditioned on +the non-exercise of one or more of the rights that are specifically +granted under this License. You may not convey a covered work if you +are a party to an arrangement with a third party that is in the +business of distributing software, under which you make payment to the +third party based on the extent of your activity of conveying the +work, and under which the third party grants, to any of the parties +who would receive the covered work from you, a discriminatory patent +license (a) in connection with copies of the covered work conveyed by +you (or copies made from those copies), or (b) primarily for and in +connection with specific products or compilations that contain the +covered work, unless you entered into that arrangement, or that patent +license was granted, prior to 28 March 2007. + +Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + +### 12. No Surrender of Others' Freedom. + +If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under +this License and any other pertinent obligations, then as a +consequence you may not convey it at all. For example, if you agree to +terms that obligate you to collect a royalty for further conveying +from those to whom you convey the Program, the only way you could +satisfy both those terms and this License would be to refrain entirely +from conveying the Program. + +### 13. Use with the GNU Affero General Public License. + +Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + +### 14. Revised Versions of this License. + +The Free Software Foundation may publish revised and/or new versions +of the GNU General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in +detail to address new problems or concerns. + +Each version is given a distinguishing version number. If the Program +specifies that a certain numbered version of the GNU General Public +License "or any later version" applies to it, you have the option of +following the terms and conditions either of that numbered version or +of any later version published by the Free Software Foundation. If the +Program does not specify a version number of the GNU General Public +License, you may choose any version ever published by the Free +Software Foundation. + +If the Program specifies that a proxy can decide which future versions +of the GNU General Public License can be used, that proxy's public +statement of acceptance of a version permanently authorizes you to +choose that version for the Program. + +Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + +### 15. Disclaimer of Warranty. + +THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT +WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND +PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE +DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR +CORRECTION. + +### 16. Limitation of Liability. + +IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR +CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, +INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES +ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT +NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR +LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM +TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER +PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +### 17. Interpretation of Sections 15 and 16. + +If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + +END OF TERMS AND CONDITIONS + +## How to Apply These Terms to Your New Programs + +If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these +terms. + +To do so, attach the following notices to the program. It is safest to +attach them to the start of each source file to most effectively state +the exclusion of warranty; and each file should have at least the +"copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper +mail. + +If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands \`show w' and \`show c' should show the +appropriate parts of the General Public License. Of course, your +program's commands might be different; for a GUI interface, you would +use an "about box". + +You should also get your employer (if you work as a programmer) or +school, if any, to sign a "copyright disclaimer" for the program, if +necessary. For more information on this, and how to apply and follow +the GNU GPL, see . + +The GNU General Public License does not permit incorporating your +program into proprietary programs. If your program is a subroutine +library, you may consider it more useful to permit linking proprietary +applications with the library. If this is what you want to do, use the +GNU Lesser General Public License instead of this License. But first, +please read . diff --git a/README.md b/README.md new file mode 100644 index 000000000..3370c23ae --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# TAT + +A Fermionic tensor library based on pytorch. diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..90c0398a4 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,31 @@ +[project] +name = "tat" +version = "0.4.0" +authors = [ + {email = "zh970205@mail.ustc.edu.cn", name = "Hao Zhang"} +] +description = "A Fermionic tensor library based on pytorch." +readme = "README.md" +requires-python = ">=3.10" +license = {text = "GPL-3.0-or-later"} +dependencies = [ + "multimethod>=1.9", + "torch>=1.12", +] + +[tool.pylint] +max-line-length = 120 +generated-members = "torch.*" + +[tool.yapf] +based_on_style = "google" +column_limit = 120 + +[tool.mypy] +check_untyped_defs = true +disallow_untyped_defs = true + +[tool.pytest.ini_options] +pythonpath = "." +testpaths = ["tests",] +addopts = "--cov=tat" diff --git a/tat/__init__.py b/tat/__init__.py new file mode 100644 index 000000000..9f8cc4bcb --- /dev/null +++ b/tat/__init__.py @@ -0,0 +1,6 @@ +""" +The tat is a Fermionic tensor library based on pytorch. +""" + +from .edge import Edge +from .tensor import Tensor diff --git a/tat/_utility.py b/tat/_utility.py new file mode 100644 index 000000000..1e890d29e --- /dev/null +++ b/tat/_utility.py @@ -0,0 +1,41 @@ +""" +Some internal utility used by tat. +""" + +import torch + +# pylint: disable=missing-function-docstring +# pylint: disable=no-else-return + + +def unsqueeze(tensor: torch.Tensor, index: int, rank: int) -> torch.Tensor: + return tensor.reshape([-1 if i == index else 1 for i in range(rank)]) + + +def neg_symmetry(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dtype is torch.bool: + return tensor + else: + return -tensor + + +def add_symmetry(tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor: + if tensor_1.dtype is torch.bool: + return torch.logical_xor(tensor_1, tensor_2) + else: + return torch.add(tensor_1, tensor_2) + + +def zero_symmetry(tensor: torch.Tensor) -> torch.Tensor: + # pylint: disable=singleton-comparison + if tensor.dtype is torch.bool: + return tensor == False + else: + return tensor == 0 + + +def parity(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dtype is torch.bool: + return tensor + else: + return tensor % 2 != 0 diff --git a/tat/compat.py b/tat/compat.py new file mode 100644 index 000000000..a64a8007c --- /dev/null +++ b/tat/compat.py @@ -0,0 +1,235 @@ +""" +This file implements a compat layer for legacy TAT interface. +""" + +from __future__ import annotations +import typing +from multimethod import multimethod +import torch +from .edge import Edge as E +from .tensor import Tensor as T + +# pylint: disable=too-few-public-methods +# pylint: disable=too-many-instance-attributes + + +class CompatSymmetry: + """ + The common Symmetry namespace. + """ + + def __init__(self: CompatSymmetry, fermion: tuple[bool, ...], dtypes: tuple[torch.dtype, ...]) -> None: + # This create fake module like TAT.No, TAT.Z2 or similar things, it need to specify the symmetry attributes. + # symmetry is set by two attributes: fermion and dtypes. + self.fermion: tuple[bool, ...] = fermion + self.dtypes: tuple[torch.dtype, ...] = dtypes + + # pylint: disable=invalid-name + self.S: CompatScalar + self.D: CompatScalar + self.C: CompatScalar + self.Z: CompatScalar + self.float32: CompatScalar + self.float64: CompatScalar + self.float: CompatScalar + self.complex64: CompatScalar + self.complex128: CompatScalar + self.complex: CompatScalar + + # In old TAT, something like TAT.No.D is a sub module for tensor with specific scalar type. + # In this compat library, it is implemented by another fake module: CompatScalar. + self.S = self.float32 = CompatScalar(self, torch.float32) + self.D = self.float64 = self.float = CompatScalar(self, torch.float64) + self.C = self.complex64 = CompatScalar(self, torch.complex64) + self.Z = self.complex128 = self.complex = CompatScalar(self, torch.complex128) + + def _parse_segments(self: CompatSymmetry, segments: list) -> tuple[tuple[torch.Tensor, ...], int]: + # In TAT, user could use [Sym] or [(Sym, Size)] to set segments of a edge, where [(Sym, Size)] is nothing but + # the symmetry and size of every blocks. While [Sym] acts like [(Sym, 1)], so try to treat input as + # [(Sym, Size)] First, if error raised, convert it from [Sym] to [(Sym, 1)] and try again. + try: + # try [(Sym, Size)] first + return self._parse_segments_kernel(segments) + except TypeError: + # Cannot unpack is a type error, value[index] is a type error, too. So only catch TypeError here. + # convert [Sym] to [(Sym, Size)] + return self._parse_segments_kernel([(sym, 1) for sym in segments]) + # This function return the symmetry tuple and dimension + + def _parse_segments_kernel(self: CompatSymmetry, + segments: list[tuple[typing.Any, int]]) -> tuple[tuple[torch.Tensor, ...], int]: + # [(Sym, Size)] for every element + dimension = sum(dim for _, dim in segments) + symmetry = tuple( + torch.tensor( + # tat.Edge need torch.Tensor as its symmetry, convert it to torch.Tensor with specific dtype. + sum( + # Concat all segment for this subsymmetry from an empty list + # Every segment is just sym[index] * dim, sometimes sym may be sub symmetry itself directly instead + # of tuple of sub symmetry, so call an utility function _parse_segments_get_subsymmetry here. + ([self._parse_segments_get_subsymmetry(sym, index)] * dim + for sym, dim in segments), + [], + ), + dtype=sub_symmetry, + ) + # Generate subsymmetry one by one + for index, sub_symmetry in enumerate(self.dtypes)) + return symmetry, dimension + + def _parse_segments_get_subsymmetry(self: CompatSymmetry, sym: object, index: int) -> object: + # Most of time, symmetry is a tuple of sub symmetry + # But if there is only one sub symmetry in the symmetry, it could not be a tuple but subsymmetry itself. + # pylint: disable=no-else-return + if isinstance(sym, tuple): + # If it is tuple, there is no need to do any other check + return sym[index] + else: + # If it is not tuple, it should be sub symmetry directly, so this symmetry only should own single sub + # symmetry, check it. + if len(self.fermion) == 1: + return sym + else: + raise TypeError(f"{sym=} is not subscriptable") + + @multimethod + def Edge(self: CompatSymmetry, dimension: int) -> E: + """ + Create edge with compat interface. + + It may be created by + 1. Edge(dimension) create trivial symmetry with specified dimension. + 2. Edge(segments, arrow) create with given segments and arrow. + 3. Edge(segments_arrow_tuple) create with a tuple of segments and arrow. + """ + # pylint: disable=invalid-name + # Generate a trivial symmetry tuple. In this tuple, every sub symmetry is a torch.zeros tensor with specific + # dtype and the same dimension. + symmetry = tuple(torch.zeros(dimension, dtype=sub_symmetry) for sub_symmetry in self.dtypes) + return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=False) + + @Edge.register + def _(self: CompatSymmetry, segments: list, arrow: bool = False) -> E: + symmetry, dimension = self._parse_segments(segments) + return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow) + + @Edge.register + def _(self: CompatSymmetry, segments_and_bool: tuple[list, bool]) -> E: + segments, arrow = segments_and_bool + symmetry, dimension = self._parse_segments(segments) + return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow) + + +class CompatScalar: + """ + The common Scalar namespace. + """ + + def __init__(self: CompatScalar, symmetry: CompatSymmetry, dtype: torch.dtype) -> None: + # This is fake module like TAT.No.D, TAT.Fermi.complex, so it records the parent symmetry information and its + # own dtype. + self.symmetry: CompatSymmetry = symmetry + self.dtype: torch.dtype = dtype + + @multimethod + def Tensor(self: CompatScalar, names: list[str], edges: list) -> T: + """ + Create tensor with compat names and edges. + + It may be create by + 1. Tensor(names, edges) The most used interface. + 2. Tensor() Create a rank-0 tensor, fill with number 1. + 3. Tensor(number, names=[], edge_symmetry=[], edge_arrow=[]) Create a size-1 tensor, with specified edge, and + filled with specified number. + """ + # pylint: disable=invalid-name + return T( + tuple(names), + tuple(self.symmetry.Edge(edge) for edge in edges), + fermion=self.symmetry.fermion, + dtypes=self.symmetry.dtypes, + dtype=self.dtype, + ) + + @Tensor.register + def _(self: CompatScalar) -> T: + result = T( + (), + (), + fermion=self.symmetry.fermion, + dtypes=self.symmetry.dtypes, + dtype=self.dtype, + ) + # To set element of rank-0 pytorch tensor, reshape it to rank-0, length-1 vector first and set it. + result.data.reshape([-1])[0] = 1 + return result + + @Tensor.register + def _( + self: CompatScalar, + number: typing.Any, + names: list[str] | None = None, + edge_symmetry: list | None = None, + edge_arrow: list[bool] | None = None, + ) -> T: + # Create high rank tensor with only one element + if names is None: + names = [] + if edge_symmetry is None: + edge_symmetry = [None for _ in names] + if edge_arrow is None: + edge_arrow = [False for _ in names] + result = T( + tuple(names), + tuple( + # Create edge for every rank, given the only symmetry(maybe None) and arrow. + E( + fermion=self.symmetry.fermion, + dtypes=self.symmetry.dtypes, + # For every edge, its symmetry is a tuple of all sub symmetry. + symmetry=tuple( + # For every sub symmetry, get the only symmetry for it, since dimension of all edge is 1. + # It should be noticed that the symmetry may be None, tuple or sub symmetry directly. + torch.tensor([self._create_size1_get_subsymmetry(symmetry, index)], dtype=dtype) + for index, dtype in enumerate(self.symmetry.dtypes)), + dimension=1, + arrow=arrow, + ) + for symmetry, arrow in zip(edge_symmetry, edge_arrow)), + fermion=self.symmetry.fermion, + dtypes=self.symmetry.dtypes, + dtype=self.dtype, + ) + # To set element of rank-0 pytorch tensor, reshape it to rank-0, length-1 vector first and set it. + result.data.reshape([-1])[0] = number + return result + + def _create_size1_get_subsymmetry(self: CompatScalar, sym: object, index: int) -> object: + # pylint: disable=no-else-return + # sym may be None, tuple or sub symmetry directly. + if sym is None: + # If is None, user may want to create symmetric edge with trivial symmetry, which should be 0 for int and + # False for bool, always return 0 here, since it will be converted to correct type by torch.tensor. + return 0 + elif isinstance(sym, tuple): + # If it is tuple, there is no need to do any other check + return sym[index] + else: + # If it is not tuple, it should be sub symmetry directly, so this symmetry only should own single sub + # symmetry, check it. + if len(self.symmetry.fermion) == 1: + return sym + else: + raise TypeError(f"{sym=} is not subscriptable") + + +# Create fake sub module for all symmetry compiled in old version TAT +No: CompatSymmetry = CompatSymmetry(fermion=(), dtypes=()) +Z2: CompatSymmetry = CompatSymmetry(fermion=(False,), dtypes=(torch.bool,)) +U1: CompatSymmetry = CompatSymmetry(fermion=(False,), dtypes=(torch.int,)) +Fermi: CompatSymmetry = CompatSymmetry(fermion=(True,), dtypes=(torch.int,)) +FermiZ2: CompatSymmetry = CompatSymmetry(fermion=(True, False), dtypes=(torch.int, torch.bool)) +FermiU1: CompatSymmetry = CompatSymmetry(fermion=(True, False), dtypes=(torch.int, torch.int)) +Parity: CompatSymmetry = CompatSymmetry(fermion=(True,), dtypes=(torch.bool,)) +FermiFermi: CompatSymmetry = CompatSymmetry(fermion=(True, True), dtypes=(torch.int, torch.int)) +Normal: CompatSymmetry = No diff --git a/tat/edge.py b/tat/edge.py new file mode 100644 index 000000000..f88ce9be6 --- /dev/null +++ b/tat/edge.py @@ -0,0 +1,284 @@ +""" +This file contains the definition of tensor edge. +""" + +from __future__ import annotations +import functools +import operator +import torch +from . import _utility + + +class Edge: + """ + The edge type of tensor. + """ + + __slots__ = "_fermion", "_dtypes", "_symmetry", "_dimension", "_arrow", "_parity" + + @property + def fermion(self: Edge) -> tuple[bool, ...]: + """ + A tuple records whether every sub symmetry is fermionic. Its length is the number of sub symmetry. + """ + return self._fermion + + @property + def dtypes(self: Edge) -> tuple[torch.dtype, ...]: + """ + A tuple records the basic dtype of every sub symmetry. Its length is the number of sub symmetry. + """ + return self._dtypes + + @property + def symmetry(self: Edge) -> tuple[torch.Tensor, ...]: + """ + A tuple containing all symmetry of this edge. Its length is the number of sub symmetry. Every element of it is a + sub symmetry. + """ + return self._symmetry + + @property + def dimension(self: Edge) -> int: + """ + The dimension of this edge. + """ + return self._dimension + + @property + def arrow(self: Edge) -> bool: + """ + The arrow of this edge. + """ + return self._arrow + + @property + def parity(self: Edge) -> torch.Tensor: + """ + The parity of this edge. + """ + return self._parity + + def __init__( + self: Edge, + *, + fermion: tuple[bool, ...] | None = None, + dtypes: tuple[torch.dtype, ...] | None = None, + symmetry: tuple[torch.Tensor, ...] | None = None, + dimension: int | None = None, + arrow: bool | None = None, + **kwargs: torch.Tensor, + ) -> None: + """ + Create an edge with essential information. + + Examples: + - Edge(dimension=5) + - Edge(symmetry=(torch.tensor([False, False, True, True]),)) + - Edge(fermion=(False, True), symmetry=(torch.tensor([False, True]), torch.tensor([False, True])), arrow=True) + + Parameters + ---------- + fermion : tuple[bool, ...], optional + Whether each sub symmetry is fermionic symmetry, its length should be the same to symmetry. But it could be + left empty, if so, a total bosonic edge will be created. + dtypes : tuple[torch.dtype, ...], optional + The basic dtype to identify each sub symmetry, its length should be the same to symmetry, and it is nothing + but the dtypes of each tensor in the symmetry. It could be left empty, if so, it will be derived from + symmetry. + symmetry : tuple[torch.Tensor, ...], optional + The symmetry information of every sub symmetry, each of sub symmetry should be a one dimensional tensor with + the same length dimension, and their dtype should be integral type, aka, int or bool. + dimension : int, optional + The dimension of the edge, if not specified, dimension will be detected from symmetry. + arrow : bool, optional + The arrow direction of the edge, it is essential for fermionic edge, aka, an edge with fermionic sub + symmetry. + """ + # Symmetry could be left empty to create no symmetry edge + if symmetry is None: + symmetry = () + + # Fermion could be empty if it is total bosonic edge + if fermion is None: + fermion = tuple(False for _ in symmetry) + + # Dtypes could be empty and derived from symmetry + if dtypes is None: + dtypes = tuple(sub_symmetry.dtype for sub_symmetry in symmetry) + # Check dtype is compatible with symmetry + assert all(sub_symmetry.dtype is sub_dtype for sub_symmetry, sub_dtype in zip(symmetry, dtypes)) + # Check dtype is valid, aka, bool or int + assert all(not (sub_symmetry.is_floating_point() or sub_symmetry.is_complex()) for sub_symmetry in symmetry) + + # The fermion, dtypes and symmetry information should have the same length + assert len(fermion) == len(dtypes) == len(symmetry) + + # If dimension not set, get dimension from symmetry + if dimension is None: + dimension = len(symmetry[0]) + # Check if the dimensions of different sub_symmetry mismatch + assert all(sub_symmetry.size() == (dimension,) for sub_symmetry in symmetry) + + if arrow is None: + # Arrow not set, it should be bosonic edge. + arrow = False + assert not any(fermion) + + self._fermion: tuple[bool, ...] = fermion + self._dtypes: tuple[torch.dtype, ...] = dtypes + self._symmetry: tuple[torch.Tensor, ...] = symmetry + self._dimension: int = dimension + self._arrow: bool = arrow + + self._parity: torch.Tensor + if "parity" in kwargs: + self._parity = kwargs.pop("parity") + assert self.parity.size() == (self.dimension,) + assert self.parity.dtype is torch.bool + else: + self._parity = self._generate_parity() + assert not kwargs + + def _generate_parity(self: Edge) -> torch.Tensor: + return functools.reduce( + # Reduce sub parity for all sub symmetry by logical xor + torch.logical_xor, + ( + # The parity of sub symmetry + _utility.parity(sub_symmetry) + # Loop all sub symmetry + for sub_symmetry, sub_fermion in zip(self.symmetry, self.fermion) + # But only reduce if it is fermion sub symmetry + if sub_fermion), + # Reduce with start as tensor filled with False + torch.zeros(self.dimension, dtype=torch.bool), + ) + + def conjugate(self: Edge) -> Edge: + """ + Get the conjugated edge. + + Returns + ------- + Edge + The conjugated edge. + """ + # The only two difference of conjguated edge is symmetry and arrow + return Edge( + fermion=self.fermion, + dtypes=self.dtypes, + symmetry=tuple( + _utility.neg_symmetry(sub_symmetry) # bool -> same, int -> neg + for sub_symmetry in self.symmetry), + dimension=self.dimension, + arrow=not self.arrow, + parity=self.parity, + ) + + def __eq__(self: Edge, other: object) -> bool: + if not isinstance(other, Edge): + return NotImplemented + return ( + # Compare int dimension and bool arrow first since they are fast to compare + self.dimension == other.dimension and + # But even if arrow are differnt, if it is bonsonic edge, it is also ok + (self.arrow == other.arrow or not any(self.fermion)) and + # Then the tuple of bool are compared + self.fermion == other.fermion and + # Then the tuple of dtypes are compared + self.dtypes == other.dtypes and + # All of symmetries are compared at last, since it is biggest + all( + torch.equal(self_sub_symmetry, other_sub_symmetry) + for self_sub_symmetry, other_sub_symmetry in zip(self.symmetry, other.symmetry))) + + def __str__(self: Edge) -> str: + # pylint: disable=no-else-return + if any(self.fermion): + # Fermionic edge + return f"(dimension={self.dimension}, arrow={self.arrow}, fermion={self.fermion}, symmetry={self.symmetry})" + elif self.fermion: + # Bosonic edge + return f"(dimension={self.dimension}, symmetry={self.symmetry})" + else: + # Trivial edge + return f"(dimension={self.dimension})" + + def __repr__(self: Edge) -> str: + return f"Edge{self.__str__()}" + + @staticmethod + def merge_edges( + edges: tuple[Edge, ...], + fermion: tuple[bool, ...] | None = None, + dtypes: tuple[torch.dtype, ...] | None = None, + arrow: bool | None = None, + ) -> Edge: + """ + Merge several edges into one edge. + + Parameters + ---------- + edges : tuple[Edge, ...] + The edges to be merged. + fermion : tuple[bool, ...], optional + Whether each sub symmetry is fermionic, it could be left empty to derive from edges + dtypes : tuple[torch.dtype, ...], optional + The base type of sub symmetry, it could be left empty to derive from edges + arrow : bool, optional + The arrow of all the edges, it is useful if edges is empty. + + Returns + ------- + Edge + The result edge merged by edges. + """ + # If fermion not set, get it from edges + if fermion is None: + fermion = edges[0].fermion + # All edge should share the same fermion + assert all(fermion == edge.fermion for edge in edges) + # If dtypes not set, get it from edges + if dtypes is None: + dtypes = edges[0].dtypes + # All edge should share the same dtypes + assert all(dtypes == edge.dtypes for edge in edges) + # If arrow set, check it directly, if not set, set to False or get from edges + if arrow is None: + if any(fermion): + # It is fermionic edge. + arrow = edges[0].arrow + else: + # It is bosonic edge, set to False directly since it is useless. + arrow = False + # All edge should share the same arrow for fermionic edge + assert (not any(fermion)) or all(arrow == edge.arrow for edge in edges) + + rank = len(edges) + # Merge edge + dimension = functools.reduce(operator.mul, (edge.dimension for edge in edges), 1) + symmetry = tuple( + # Every merged sub symmetry is calculated by reduce and flatten + functools.reduce( + # The reduce operator depend on the dtype of this sub symmetry + _utility.add_symmetry, + ( + # The sub symmetry of every edge will be reshape to be reduced. + _utility.unsqueeze(edge.symmetry[sub_symmetry_index], current_index, rank) + # The sub symmetry of every edge is reduced one by one + for current_index, edge in enumerate(edges)), + # Reduce from a rank-0 tensor + torch.zeros([], dtype=sub_symmetry_dtype), + ).reshape([-1]) + # Merge every sub symmetry one by one + for sub_symmetry_index, sub_symmetry_dtype in enumerate(dtypes)) + + # parity not set here since it need recalculation + return Edge( + fermion=fermion, + dtypes=dtypes, + symmetry=symmetry, + dimension=dimension, + arrow=arrow, + ) diff --git a/tat/tensor.py b/tat/tensor.py new file mode 100644 index 000000000..fcb6789e8 --- /dev/null +++ b/tat/tensor.py @@ -0,0 +1,1885 @@ +""" +This file defined the core tensor type for tat package. +""" + +from __future__ import annotations +import typing +import operator +import functools +from multimethod import multimethod +import torch +from . import _utility +from .edge import Edge + +# pylint: disable=too-many-public-methods +# pylint: disable=too-many-lines + + +class Tensor: + """ + The main tensor type, which wraps pytorch tensor and provides edge names and Fermionic functions. + """ + + __slots__ = "_fermion", "_dtypes", "_names", "_edges", "_data", "_mask" + + def __str__(self: Tensor) -> str: + return f"(names={self.names}, edges={self.edges}, data={self.data})" + + def __repr__(self: Tensor) -> str: + return f"Tensor(names={self.names}, edges={self.edges})" + + @property + def fermion(self: Tensor) -> tuple[bool, ...]: + """ + A tuple records whether every sub symmetry is fermionic. Its length is the number of sub symmetry. + """ + return self._fermion + + @property + def dtypes(self: Tensor) -> tuple[torch.dtype, ...]: + """ + A tuple records the basic dtype of every sub symmetry. Its length is the number of sub symmetry. + """ + return self._dtypes + + @property + def names(self: Tensor) -> tuple[str, ...]: + """ + The edge names of this tensor. + """ + return self._names + + @property + def edges(self: Tensor) -> tuple[Edge, ...]: + """ + The edges information of this tensor. + """ + return self._edges + + @property + def data(self: Tensor) -> torch.Tensor: + """ + The content data of this tensor. + """ + return self._data + + @property + def mask(self: Tensor) -> torch.Tensor: + """ + The content data mask of this tensor. + """ + return self._mask + + @property + def rank(self: Tensor) -> int: + """ + The rank of this tensor. + """ + return len(self._names) + + @property + def dtype(self: Tensor) -> torch.dtype: + """ + The data type of the content in this tensor. + """ + return self.data.dtype + + @property + def btype(self: Tensor) -> str: + """ + The data type of the content in this tensor, represented in BLAS/LAPACK convention. + """ + if self.dtype is torch.float32: + return 'S' + if self.dtype is torch.float64: + return 'D' + if self.dtype is torch.complex64: + return 'C' + if self.dtype is torch.complex128: + return 'Z' + return '?' + + @property + def is_complex(self: Tensor) -> bool: + """ + Whether it is a complex tensor + """ + return self.dtype.is_complex + + @property + def is_real(self: Tensor) -> bool: + """ + Whether it is a real tensor + """ + return self.dtype.is_floating_point + + def edge_by_name(self: Tensor, name: str) -> Edge: + """ + Get edge by the edge name of this tensor. + + Parameters + ---------- + name : str + The given edge name. + + Returns + ------- + Edge + The edge with the given edge name. + """ + assert name in self.names + return self.edges[self.names.index(name)] + + def _arithmetic_operator(self: Tensor, other: object, operate: typing.Callable) -> Tensor: + new_data: torch.Tensor + if isinstance(other, Tensor): + # If it is tensor, check same shape and transpose before calculating. + assert self.same_shape_with(other) + new_data = operate(self.data, other.transpose(self.names).data) + if operate is torch.div: + # In div, it may generate nan + new_data = torch.where(self.mask, new_data, torch.zeros([], dtype=self.dtype)) + else: + # Otherwise treat other as a scalar, mask should be applied later. + new_data = operate(self.data, other) + new_data = torch.where(self.mask, new_data, torch.zeros([], dtype=self.dtype)) + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=new_data, + mask=self.mask, + ) + + def __add__(self: Tensor, other: object) -> Tensor: + return self._arithmetic_operator(other, torch.add) + + def __sub__(self: Tensor, other: object) -> Tensor: + return self._arithmetic_operator(other, torch.sub) + + def __mul__(self: Tensor, other: object) -> Tensor: + return self._arithmetic_operator(other, torch.mul) + + def __truediv__(self: Tensor, other: object) -> Tensor: + return self._arithmetic_operator(other, torch.div) + + def _right_arithmetic_operator(self: Tensor, other: object, operate: typing.Callable) -> Tensor: + new_data: torch.Tensor + if isinstance(other, Tensor): + # If it is tensor, check same shape and transpose before calculating. + assert self.same_shape_with(other) + new_data = operate(other.transpose(self.names).data, self.data) + if operate is torch.div: + # In div, it may generate nan + new_data = torch.where(self.mask, new_data, torch.zeros([], dtype=self.dtype)) + else: + # Otherwise treat other as a scalar, mask should be applied later. + new_data = operate(other, self.data) + new_data = torch.where(self.mask, new_data, torch.zeros([], dtype=self.dtype)) + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=new_data, + mask=self.mask, + ) + + def __radd__(self: Tensor, other: object) -> Tensor: + return self._right_arithmetic_operator(other, torch.add) + + def __rsub__(self: Tensor, other: object) -> Tensor: + return self._right_arithmetic_operator(other, torch.sub) + + def __rmul__(self: Tensor, other: object) -> Tensor: + return self._right_arithmetic_operator(other, torch.mul) + + def __rtruediv__(self: Tensor, other: object) -> Tensor: + return self._right_arithmetic_operator(other, torch.div) + + def _inplace_arithmetic_operator(self: Tensor, other: object, operate: typing.Callable) -> Tensor: + if isinstance(other, Tensor): + # If it is tensor, check same shape and transpose before calculating. + assert self.same_shape_with(other) + operate(self.data, other.transpose(self.names).data, out=self.data) + if operate is torch.div: + # In div, it may generate nan + torch.where(self.mask, self.data, torch.zeros([], dtype=self.dtype), out=self.data) + else: + # Otherwise treat other as a scalar, mask should be applied later. + operate(self.data, other, out=self.data) + torch.where(self.mask, self.data, torch.zeros([], dtype=self.dtype), out=self.data) + return self + + def __iadd__(self: Tensor, other: object) -> Tensor: + return self._inplace_arithmetic_operator(other, torch.add) + + def __isub__(self: Tensor, other: object) -> Tensor: + return self._inplace_arithmetic_operator(other, torch.sub) + + def __imul__(self: Tensor, other: object) -> Tensor: + return self._inplace_arithmetic_operator(other, torch.mul) + + def __itruediv__(self: Tensor, other: object) -> Tensor: + return self._inplace_arithmetic_operator(other, torch.div) + + def __pos__(self: Tensor) -> Tensor: + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=+self.data, + mask=self.mask, + ) + + def __neg__(self: Tensor) -> Tensor: + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=-self.data, + mask=self.mask, + ) + + def __float__(self: Tensor) -> float: + return float(self.data) + + def __complex__(self: Tensor) -> complex: + return complex(self.data) + + def norm(self: Tensor, order: typing.Any) -> float: + """ + Get the norm of the tensor, this function will flatten tensor first before calculate norm. + + Parameters + ---------- + order + The order of norm. + + Returns + ------- + float + The norm of the tensor. + """ + return torch.linalg.vector_norm(self.data, ord=order) + + def norm_max(self: Tensor) -> float: + "max norm" + return self.norm(+torch.inf) + + def norm_min(self: Tensor) -> float: + "min norm" + return self.norm(-torch.inf) + + def norm_num(self: Tensor) -> float: + "0-norm" + return self.norm(0) + + def norm_sum(self: Tensor) -> float: + "1-norm" + return self.norm(1) + + def norm_2(self: Tensor) -> float: + "2-norm" + return self.norm(2) + + def copy(self: Tensor) -> Tensor: + """ + Get a copy of this tensor + + Returns + ------- + Tensor + The copy of this tensor + """ + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=torch.clone(self.data), + mask=self.mask, + ) + + def __copy__(self: Tensor) -> Tensor: + return self.copy() + + def __deepcopy__(self: Tensor, _: typing.Any = None) -> Tensor: + return self.copy() + + def same_shape(self: Tensor) -> Tensor: + """ + Get a tensor with same shape to this tensor + + Returns + ------- + Tensor + A new tensor with the same shape to this tensor + """ + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=torch.zeros_like(self.data), + mask=self.mask, + ) + + def zero_(self: Tensor) -> Tensor: + """ + Set all element to zero in this tensor + + Returns + ------- + Tensor + Return this tensor itself. + """ + self.data.zero_() + return self + + def sqrt(self: Tensor) -> Tensor: + """ + Get the sqrt of the tensor. + + Returns + ------- + Tensor + The sqrt of this tensor. + """ + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=torch.sqrt(torch.abs(self.data)), + mask=self.mask, + ) + + def reciprocal(self: Tensor) -> Tensor: + """ + Get the reciprocal of the tensor. + + Returns + ------- + Tensor + The reciprocal of this tensor. + """ + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=torch.where(self.data == 0, self.data, 1 / self.data), + mask=self.mask, + ) + + def range_(self: Tensor, first: typing.Any = 0, step: typing.Any = 1) -> Tensor: + """ + A useful function to generate simple data in tensor for test. + + Parameters + ---------- + first, step + Parameters to generate data. + + Returns + ------- + Tensor + Returns the tensor itself. + """ + data = torch.cumsum(self.mask.reshape([-1]), dim=0, dtype=self.data.dtype).reshape(self.data.size()) + data = (data - 1) * step + first + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=data, + mask=self.mask, + ) + + def to(self: Tensor, new_type: typing.Any) -> Tensor: + """ + Convert this tensor to other scalar type. + + Parameters + ---------- + new_type + The scalar data type of the new tensor. + """ + # pylint: disable=invalid-name + if isinstance(new_type, str): + if new_type in ["float32", "S"]: + new_type = torch.float32 + elif new_type in ["float64", "float", "D"]: + new_type = torch.float64 + elif new_type in ["complex64", "C"]: + new_type = torch.complex64 + elif new_type in ["complex128", "complex", "Z"]: + new_type = torch.complex128 + return Tensor( + names=self.names, + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=self.data.to(new_type), + mask=self.mask, + ) + + def __init__( + self: Tensor, + names: tuple[str, ...], + edges: tuple[Edge, ...], + *, + dtype: torch.dtype | None = None, + fermion: tuple[bool, ...] | None = None, + dtypes: tuple[torch.dtype, ...] | None = None, + **kwargs: torch.Tensor, + ) -> None: + """ + Create a tensor with specific shape. + + Parameters + ---------- + names : tuple[str, ...] + The edge names of the tensor, which length is just the tensor rank. + edges : tuple[Edge, ...] + The detail information of each edge, which length is just the tensor rank. + dtype : torch.dtype, optional + The dtype of the tensor, left it empty to let pytorch choose default dtype. + fermion : tuple[bool, ...], optional + Whether each sub symmetry is fermionic, it could be left empty to derive from edges + dtypes : tuple[torch.dtype, ...], optional + The base type of sub symmetry, it could be left empty to derive from edges + """ + # Check the rank is correct in names and edges + assert len(names) == len(edges) + # Check whether there are duplicated names + assert len(set(names)) == len(names) + # If fermion not set, get it from edges + if fermion is None: + fermion = edges[0].fermion + # If dtypes not set, get it from edges + if dtypes is None: + dtypes = edges[0].dtypes + # Check if fermion is correct + assert all(edge.fermion == fermion for edge in edges) + # Check if dtypes is correct + assert all(edge.dtypes == dtypes for edge in edges) + + self._fermion: tuple[bool, ...] = fermion + self._dtypes: tuple[torch.dtype, ...] = dtypes + self._names: tuple[str, ...] = names + self._edges: tuple[Edge, ...] = edges + + self._data: torch.Tensor + if "data" in kwargs: + self._data = kwargs.pop("data") + assert self.data.size() == tuple(edge.dimension for edge in self.edges) + assert dtype is None or self.data.dtype is dtype + else: + if dtype is None: + self._data = torch.zeros([edge.dimension for edge in self.edges]) + else: + self._data = torch.zeros([edge.dimension for edge in self.edges], dtype=dtype) + self._mask: torch.Tensor + if "mask" in kwargs: + self._mask = kwargs.pop("mask") + assert self.mask.size() == self.data.size() + assert self.mask.dtype is torch.bool + else: + self._mask = self._generate_mask() + assert not kwargs + + def _generate_mask(self: Tensor) -> torch.Tensor: + return functools.reduce( + # Mask is valid if all sub symmetry give valid mask. + torch.logical_and, + ( + # The mask is valid if total symmetry is False or total symmetry is 0 + _utility.zero_symmetry( + # total sub symmetry is calculated by reduce + functools.reduce( + # The reduce operator depend on the dtype of this sub symmetry + _utility.add_symmetry, + ( + # The sub symmetry of every edge will be reshape to be reduced. + _utility.unsqueeze(edge.symmetry[sub_symmetry_index], current_index, self.rank) + # The sub symmetry of every edge is reduced one by one + for current_index, edge in enumerate(self.edges)), + # Reduce from a rank-0 tensor + torch.zeros([], dtype=sub_symmetry_dtype), + )) + # Calculate mask on every sub symmetry one by one + for sub_symmetry_index, sub_symmetry_dtype in enumerate(self.dtypes)), + # Reduce from all true mask + torch.ones(size=self.data.size(), dtype=torch.bool), + ) + + @multimethod + def _prepare_position(self: Tensor, position: tuple[int, ...]) -> tuple[int, ...]: + indices: tuple[int, ...] = position + assert len(indices) == self.rank + assert all(0 <= index < edge.dimension for edge, index in zip(self.edges, indices)) + return indices + + @_prepare_position.register + def _(self: Tensor, position: tuple[slice, ...]) -> tuple[int, ...]: + index_by_name: dict[str, int] = {s.start: s.stop for s in position} + indices: tuple[int, ...] = tuple(index_by_name[name] for name in self.names) + assert len(indices) == self.rank + assert all(0 <= index < edge.dimension for edge, index in zip(self.edges, indices)) + return indices + + @_prepare_position.register + def _(self: Tensor, position: dict[str, int]) -> tuple[int, ...]: + indices: tuple[int, ...] = tuple(position[name] for name in self.names) + assert len(indices) == self.rank + assert all(0 <= index < edge.dimension for edge, index in zip(self.edges, indices)) + return indices + + def __getitem__(self: Tensor, position: tuple[int, ...] | tuple[slice, ...] | dict[str, int]) -> typing.Any: + """ + Get the element of the tensor + + Parameters + ---------- + position : tuple[int, ...] | tuple[slice, ...] | dict[str, int] + The position of the element, which could be either tuple of index directly or a map from edge name to the + index in the corresponding edge. + """ + indices: tuple[int, ...] = self._prepare_position(position) + return self.data[indices] + + def __setitem__(self: Tensor, position: tuple[int, ...] | tuple[slice, ...] | dict[str, int], + value: typing.Any) -> None: + """ + Set the element of the tensor + + Parameters + ---------- + position : tuple[int, ...] | tuple[slice, ...] | dict[str, int] + The position of the element, which could be either tuple of index directly or a map from edge name to the + index in the corresponding edge. + """ + indices = self._prepare_position(position) + if self.mask[indices]: + self.data[indices] = value + + def clear_symmetry(self: Tensor) -> Tensor: + """ + Clear all symmetry of this tensor. + + Returns + ------- + Tensor + The result tensor with symmetry cleared. + """ + # Mask must be generated again here + # pylint: disable=no-else-return + if any(self.fermion): + return Tensor( + names=self.names, + edges=tuple( + Edge( + fermion=(True,), + dtypes=(torch.bool,), + symmetry=(edge.parity,), + dimension=edge.dimension, + arrow=edge.arrow, + parity=edge.parity, + ) for edge in self.edges), + fermion=(True,), + dtypes=(torch.bool,), + data=self.data, + ) + else: + return Tensor( + names=self.names, + edges=tuple( + Edge( + fermion=(), + dtypes=(), + symmetry=(), + dimension=edge.dimension, + arrow=edge.arrow, + parity=edge.parity, + ) for edge in self.edges), + fermion=(), + dtypes=(), + data=self.data, + ) + + def randn_(self: Tensor, mean: float = 0., std: float = 1.) -> Tensor: + """ + Fill the tensor with random number in normal distribution. + + Parameters + ---------- + mean, std : float + The parameter of normal distribution. + + Returns + ------- + Tensor + Return this tensor itself. + """ + self.data.normal_(mean, std) + torch.where(self.mask, self.data, torch.zeros([], dtype=self.dtype), out=self.data) + return self + + def rand_(self: Tensor, low: float = 0., high: float = 1.) -> Tensor: + """ + Fill the tensor with random number in uniform distribution. + + Parameters + ---------- + low, high : float + The parameter of uniform distribution. + + Returns + ------- + Tensor + Return this tensor itself. + """ + self.data.uniform_(low, high) + torch.where(self.mask, self.data, torch.zeros([], dtype=self.dtype), out=self.data) + return self + + def same_type_with(self: Tensor, other: Tensor) -> bool: + """ + Check whether two tensor has the same type, that is to say they share the same symmetry type. + """ + return self.fermion == other.fermion and self.dtypes == other.dtypes + + def same_shape_with(self: Tensor, other: Tensor, *, allow_transpose: bool = True) -> bool: + """ + Check whether two tensor has the same shape, that is to say the only differences between them are transpose and + data difference. + """ + if not self.same_type_with(other): + return False + # pylint: disable=no-else-return + if allow_transpose: + return dict(zip(self.names, self.edges)) == dict(zip(other.names, other.edges)) + else: + return self.names == other.names and self.edges == other.edges + + def edge_rename(self: Tensor, name_map: dict[str, str]) -> Tensor: + """ + Rename edge name for this tensor. + + Parameters + ---------- + name_map : dict[str, str] + The name map to be used in renaming edge name. + + Returns + ------- + Tensor + The tensor with names renamed. + """ + return Tensor( + names=tuple(name_map.get(name, name) for name in self.names), + edges=self.edges, + fermion=self.fermion, + dtypes=self.dtypes, + data=self.data, + mask=self.mask, + ) + + def transpose(self: Tensor, names: tuple[str, ...]) -> Tensor: + """ + Transpose the tensor outplace. + + Parameters + ---------- + names : tuple[str, ...] + The new edge order identified by edge names. + + Returns + ------- + Tensor + The transpsoe tensor. + """ + if names == self.names: + return self + assert len(names) == len(self.names) + assert set(names) == set(self.names) + before_by_after = tuple(self.names.index(name) for name in names) + after_by_before = tuple(names.index(name) for name in self.names) + data = self.data.permute(before_by_after) + mask = self.mask.permute(before_by_after) + if any(self.fermion): + # It is fermionic tensor + parities_before_transpose = tuple( + _utility.unsqueeze(edge.parity, current_index, self.rank) + for current_index, edge in enumerate(self.edges)) + # Generate parity by xor all inverse pairs + parity = functools.reduce( + torch.logical_xor, + ( + torch.logical_and(parities_before_transpose[i], parities_before_transpose[j]) + # Loop every 0 <= i < j < rank + for j in range(self.rank) + for i in range(0, j) + if after_by_before[i] > after_by_before[j]), + torch.zeros([], dtype=torch.bool)) + # parity True -> -x + # parity False -> +x + data = torch.where(parity, -data, +data) + return Tensor( + names=names, + edges=tuple(self.edges[index] for index in before_by_after), + fermion=self.fermion, + dtypes=self.dtypes, + data=data, + mask=mask, + ) + + def reverse_edge( + self: Tensor, + reversed_names: set[str], + apply_parity: bool = False, + parity_exclude_names: set[str] | None = None, + ) -> Tensor: + """ + Reverse some edge in the tensor. + + Parameters + ---------- + reversed_names : set[str] + The edge names of those edges which will be reversed + apply_parity : bool, default=False + Whether to apply parity caused by reversing edge, since reversing edge will generate half a sign. + parity_exclude_names : set[str], optional + The name of edges in the different behavior other than default set by apply_parity. + + Returns + ------- + Tensor + The tensor with edges reversed. + """ + if not any(self.fermion): + return self + if parity_exclude_names is None: + parity_exclude_names = set() + assert all(name in self.names for name in reversed_names) + assert all(name in reversed_names for name in parity_exclude_names) + # Parity is xor of all valid reverse parity + parity = functools.reduce( + torch.logical_xor, + ( + _utility.unsqueeze(edge.parity, current_index, self.rank) + # Loop over all edge + for current_index, [name, edge] in enumerate(zip(self.names, self.edges)) + # Check if this edge is reversed and if this edge will be applied parity + if (name in reversed_names) and (apply_parity ^ (name in parity_exclude_names))), + torch.zeros([], dtype=torch.bool), + ) + data = torch.where(parity, -self.data, +self.data) + return Tensor( + names=self.names, + edges=tuple( + Edge( + fermion=edge.fermion, + dtypes=edge.dtypes, + symmetry=edge.symmetry, + dimension=edge.dimension, + arrow=not edge.arrow if self.names[current_index] in reversed_names else edge.arrow, + parity=edge.parity, + ) for current_index, edge in enumerate(self.edges)), + fermion=self.fermion, + dtypes=self.dtypes, + data=data, + mask=self.mask, + ) + + @staticmethod + def _split_edge_get_name_group( + name: str, + split_map: dict[str, tuple[tuple[str, Edge], ...]], + ) -> list[str]: + split_group: tuple[tuple[str, Edge], ...] | None = split_map.get(name, None) + # pylint: disable=no-else-return + if split_group is None: + return [name] + else: + return [new_name for new_name, _ in split_group] + + @staticmethod + def _split_edge_get_edge_group( + name: str, + edge: Edge, + split_map: dict[str, tuple[tuple[str, Edge], ...]], + ) -> list[Edge]: + split_group: tuple[tuple[str, Edge], ...] | None = split_map.get(name, None) + # pylint: disable=no-else-return + if split_group is None: + return [edge] + else: + return [new_edge for _, new_edge in split_group] + + def split_edge( + self: Tensor, + split_map: dict[str, tuple[tuple[str, Edge], ...]], + apply_parity: bool = False, + parity_exclude_names: set[str] | None = None, + ) -> Tensor: + """ + Split some edges in this tensor. + + Parameters + ---------- + split_map : dict[str, tuple[tuple[str, Edge], ...]] + The edge splitting plan. + apply_parity : bool, default=False + Whether to apply parity caused by splitting edge, since splitting edge will generate half a sign. + parity_exclude_names : set[str], optional + The name of edges in the different behavior other than default set by apply_parity. + + Returns + ------- + Tensor + The tensor with edges splitted. + """ + if parity_exclude_names is None: + parity_exclude_names = set() + # Check the edge to be splitted can be merged by result edges. + assert all( + self.edge_by_name(name) == Edge.merge_edges( + tuple(new_edge for _, new_edge in split_result), + fermion=self.fermion, + dtypes=self.dtypes, + arrow=self.edge_by_name(name).arrow, + ) for name, split_result in split_map.items()) + assert all(name in split_map for name in parity_exclude_names) + # Calculate the result components + names: tuple[str, ...] = tuple( + # Convert the list generated by reduce to tuple + functools.reduce( + # Concat list + operator.add, + # If name in split_map, use the new names list, otherwise use name itself as a length-1 list + (Tensor._split_edge_get_name_group(name, split_map) for name in self.names), + # Reduce from [] to concat all list + [], + )) + edges: tuple[Edge, ...] = tuple( + # Convert the list generated by reduce to tuple + functools.reduce( + # Concat list + operator.add, + # If name in split_map, use the new edges list, otherwise use the edge itself as a length-1 list + (Tensor._split_edge_get_edge_group(name, edge, split_map) for name, edge in zip(self.names, self.edges) + ), + # Reduce from [] to concat all list + [], + )) + new_size = [edge.dimension for edge in edges] + data = self.data.reshape(new_size) + mask = self.mask.reshape(new_size) + + # Apply parity + if any(self.fermion): + # It is femionic tensor, parity need to be applied + new_rank = len(names) + # Parity is xor of all valid splitting parity + parity = functools.reduce( + torch.logical_xor, + ( + # Apply the parity for this splitting group here + # It is need to perform a total transpose on this split group + # {sum 0<=i tuple[str, ...]: + reversed_names: list[str] = [] + for name in reversed(self.names): + found_in_merge_map: tuple[str, tuple[str, ...]] | None = next( + ((new_name, old_names) for new_name, old_names in merge_map.items() if name in old_names), None) + if found_in_merge_map is None: + # This edge will not be merged + reversed_names.append(name) + else: + new_name, old_names = found_in_merge_map + # This edge will be merged + if name == old_names[-1]: + # Add new edge only if it is the last edge + reversed_names.append(new_name) + # Some edge is merged from no edges, it should be considered + for new_name, old_names in merge_map.items(): + if not old_names: + reversed_names.append(new_name) + return tuple(reversed(reversed_names)) + + @staticmethod + def _merge_edge_get_name_group(name: str, merge_map: dict[str, tuple[str, ...]]) -> list[str]: + merge_group: tuple[str, ...] | None = merge_map.get(name, None) + # pylint: disable=no-else-return + if merge_group is None: + return [name] + else: + return list(merge_group) + + def merge_edge( + self: Tensor, + merge_map: dict[str, tuple[str, ...]], + apply_parity: bool = False, + parity_exclude_names: set[str] | None = None, + *, + merge_arrow: dict[str, bool] | None = None, + names: tuple[str, ...] | None = None, + ) -> Tensor: + """ + Merge some edges in this tensor. + + Parameters + ---------- + merge_map : dict[str, tuple[str, ...]] + The edge merging plan. + apply_parity : bool, default=False + Whether to apply parity caused by merging edge, since merging edge will generate half a sign. + parity_exclude_names : set[str], optional + The name of edges in the different behavior other than default set by apply_parity. + merge_arrow : dict[str, bool], optional + For merging edge from zero edges, arrow cannot be identified automatically, it requires user set manually. + names : tuple[str, ...], optional + The edge order of the result, sometimes user may want to specify it manually. + + Returns + ------- + Tensor + The tensor with edges merged. + """ + # pylint: disable=too-many-locals + if parity_exclude_names is None: + parity_exclude_names = set() + if merge_arrow is None: + merge_arrow = {} + assert all(all(old_name in self.names for old_name in old_names) for _, old_names in merge_map.items()) + assert all(name in merge_map for name in parity_exclude_names) + # Two steps: 1. Transpose 2. Merge + if names is None: + names = self._merge_edge_get_names(merge_map) + transposed_names: tuple[str, ...] = tuple( + # Convert the list generated by reduce to tuple + functools.reduce( + # Concat list + operator.add, + # If name in merge_map, use the old names list, otherwise use name itself as a length-1 list + (Tensor._merge_edge_get_name_group(name, merge_map) for name in names), + # Reduce from [] to concat all list + [], + )) + transposed_tensor = self.transpose(transposed_names) + # Prepare a name to index map, since we need to look up it frequently later. + transposed_name_map = {name: index for index, name in enumerate(transposed_tensor.names)} + edges = tuple( + # If name is created by merging, call Edge.merge_edges to get the merged edge, otherwise get it directly + # from tranposed_tensor. + Edge.merge_edges( + edges=tuple(transposed_tensor.edges[transposed_name_map[old_name]] + for old_name in merge_map[name]), + fermion=self.fermion, + dtypes=self.dtypes, + arrow=merge_arrow.get(name, None), + # If merging edge from zero edge, arrow need to be set manually + ) if name in merge_map else transposed_tensor.edges[transposed_name_map[name]] + # Loop over names + for name in names) + transposed_data = transposed_tensor.data + transposed_mask = transposed_tensor.mask + + # Apply parity + if any(self.fermion): + # It is femionic tensor, parity need to be applied + # Parity is xor of all valid merging parity + parity = functools.reduce( + torch.logical_xor, + ( + # Apply the parity for this merging group here + # It is need to perform a total transpose on this merging group + # {sum 0<=i Tensor: + """ + Contract two tensors. + + Parameters + ---------- + other : Tensor + Another tensor to be contracted. + contract_pairs : set[tuple[str, str]] + The pairs of edges to be contract between two tensors. + fuse_names : set[str], optional + The set of edges to be fuses. + + Returns + ------- + Tensor + The result contracted by two tensors. + """ + # pylint: disable=too-many-locals + # Only same type tensor can be contracted. + assert self.same_type_with(other) + + if fuse_names is None: + fuse_names = set() + # Fuse name should not have any symmetry + assert all( + all(_utility.zero_symmetry(sub_symmetry) + for sub_symmetry in self.edge_by_name(fuse_name).symmetry) + for fuse_name in fuse_names) + + # Alias tensor + tensor_1: Tensor = self + tensor_2: Tensor = other + + # Check if contract edge and fuse edge compatible + assert all(tensor_1.edge_by_name(name) == tensor_2.edge_by_name(name) for name in fuse_names) + assert all( + tensor_1.edge_by_name(name_1).conjugate() == tensor_2.edge_by_name(name_2) + for name_1, name_2 in contract_pairs) + + # All tensor edges merged to three part: fuse edge, contract edge, free edge + + # Contract of tensor has 5 step: + # 1. reverse arrow + # reverse all free edge and fuse edge to arrow False, without parity apply. + # reverse contract edge to two arrow: False(tensor_2) and True(tensor_1), only apply parity to one tensor. + # 2. merge edge + # merge all edge in the same part to one edge, only apply parity to contract edge of one tensor + # free edge and fuse edge will not be applied parity. + # 3. contract matrix + # call matrix multiply + # 4. split edge + # split edge merged in step 2, without apply parity + # 5. reverse arrow + # reverse arrow reversed in step 1, without parity apply + + # Step 1 + contract_names_1: set[str] = {name_1 for name_1, name_2 in contract_pairs} + contract_names_2: set[str] = {name_2 for name_1, name_2 in contract_pairs} + arrow_true_names_1: set[str] = {name for name, edge in zip(tensor_1.names, tensor_1.edges) if edge.arrow} + arrow_true_names_2: set[str] = {name for name, edge in zip(tensor_2.names, tensor_2.edges) if edge.arrow} + + # tensor 1: contract_names & arrow_false | not contract_names & arrow_true -> contract_names ^ arrow_true + tensor_1 = tensor_1.reverse_edge(contract_names_1 ^ arrow_true_names_1, False, + contract_names_1 - arrow_true_names_1) + tensor_2 = tensor_2.reverse_edge(arrow_true_names_2, False, set()) + + # Step 2 + free_edges_1: tuple[tuple[str, Edge], ...] = tuple((name, edge) + for name, edge in zip(tensor_1.names, tensor_1.edges) + if name not in fuse_names and name not in contract_names_1) + free_names_1: tuple[str, ...] = tuple(name for name, _ in free_edges_1) + free_edges_2: tuple[tuple[str, Edge], ...] = tuple((name, edge) + for name, edge in zip(tensor_2.names, tensor_2.edges) + if name not in fuse_names and name not in contract_names_2) + free_names_2: tuple[str, ...] = tuple(name for name, _ in free_edges_2) + # Check which tensor is bigger, and use it to determine the fuse and contract edge order. + ordered_fuse_edges: tuple[tuple[str, Edge], ...] + ordered_fuse_names: tuple[str, ...] + ordered_contract_names_1: tuple[str, ...] + ordered_contract_names_2: tuple[str, ...] + if tensor_1.data.nelement() > tensor_2.data.nelement(): + # Tensor 1 larger, fit by tensor 1 + ordered_fuse_edges = tuple( + (name, edge) for name, edge in zip(tensor_1.names, tensor_1.edges) if name in fuse_names) + ordered_fuse_names = tuple(name for name, _ in ordered_fuse_edges) + + # pylint: disable=unnecessary-comprehension + contract_names_map = {name_1: name_2 for name_1, name_2 in contract_pairs} + ordered_contract_names_1 = tuple(name for name in tensor_1.names if name in contract_names_1) + ordered_contract_names_2 = tuple(contract_names_map[name] for name in contract_names_1) + else: + # Tensor 2 larger, fit by tensor 2 + ordered_fuse_edges = tuple( + (name, edge) for name, edge in zip(tensor_2.names, tensor_2.edges) if name in fuse_names) + ordered_fuse_names = tuple(name for name, _ in ordered_fuse_edges) + + contract_names_map = {name_2: name_1 for name_1, name_2 in contract_pairs} + ordered_contract_names_2 = tuple(name for name in tensor_2.names if name in contract_names_2) + ordered_contract_names_1 = tuple(contract_names_map[name] for name in contract_names_2) + + put_contract_1_right: bool = next( + (name in contract_names_1 for name in reversed(tensor_1.names) if name not in fuse_names), True) + put_contract_2_right: bool = next( + (name in contract_names_2 for name in reversed(tensor_2.names) if name not in fuse_names), False) + + tensor_1 = tensor_1.merge_edge( + { + "Free_1": free_names_1, + "Contract_1": ordered_contract_names_1, + "Fuse_1": ordered_fuse_names, + }, + False, + {"Contract_1"}, + merge_arrow={ + "Free_1": False, + "Contract_1": True, + "Fuse_1": False, + }, + names=("Fuse_1", "Free_1", "Contract_1") if put_contract_1_right else ("Fuse_1", "Contract_1", "Free_1"), + ) + tensor_2 = tensor_2.merge_edge( + { + "Free_2": free_names_2, + "Contract_2": ordered_contract_names_2, + "Fuse_2": ordered_fuse_names, + }, + False, + set(), + merge_arrow={ + "Free_2": False, + "Contract_2": False, + "Fuse_2": False, + }, + names=("Fuse_2", "Free_2", "Contract_2") if put_contract_2_right else ("Fuse_2", "Contract_2", "Free_2"), + ) + # C[fuse, free1, free2] = A[fuse, free1 contract] B[fuse, contract free2] + assert tensor_1.edge_by_name("Fuse_1") == tensor_2.edge_by_name("Fuse_2") + assert tensor_1.edge_by_name("Contact_1").conjugate() == tensor_2.edge_by_name("Contract_2") + + # Step 3 + # The standard arrow is + # (0, False, True) (0, False, False) + # aka: (a b) (c d) (c+ b+) = (a d) + # since: EPR pair order is (False True) + # put_contract_1_right should be True + # put_contract_2_right should be False + # Every mismatch generate a sign + # Total sign is + # (!put_contract_1_right) ^ (put_contract_2_right) = put_contract_1_right == put_contract_2_right + data = torch.einsum( + "b" + ("ic" if put_contract_1_right else "ci") + ",b" + ("jc" if put_contract_2_right else "cj") + "->bij", + tensor_1.data, tensor_2.data) + if put_contract_1_right == put_contract_2_right: + data = torch.where(tensor_2.edge_by_name("Free_2").parity.reshape([1, 1, -1]), -data, +data) + tensor = Tensor( + names=("Fuse", "Free_1", "Free_2"), + edges=(tensor_1.edge_by_name("Fuse_1"), tensor_1.edge_by_name("Free_1"), tensor_2.edge_by_name("Free_2")), + fermion=self.fermion, + dtypes=self.dtypes, + data=data, + ) + + # Step 4 + tensor = tensor.split_edge({ + "Fuse": ordered_fuse_edges, + "Free_1": free_edges_1, + "Free_2": free_edges_2 + }, False, set()) + + # Step 5 + tensor = tensor.reverse_edge( + (arrow_true_names_1 - contract_names_1) | (arrow_true_names_2 - contract_names_2), + False, + set(), + ) + + return tensor + + def _trace_group_edge( + self: Tensor, + trace_pairs: set[tuple[str, str]], + fuse_names: dict[str, tuple[str, str]], + ) -> tuple[ + tuple[str, ...], + tuple[str, ...], + tuple[str, ...], + tuple[str, ...], + tuple[str, ...], + tuple[str, ...], + tuple[int, ...], + tuple[int, ...], + ]: + # pylint: disable=too-many-locals + # pylint: disable=unnecessary-comprehension + trace_map = { + old_name_1: old_name_2 for old_name_1, old_name_2 in trace_pairs + } | { + old_name_2: old_name_1 for old_name_1, old_name_2 in trace_pairs + } + fuse_map = { + old_name_1: (old_name_2, new_name) for new_name, [old_name_1, old_name_2] in fuse_names.items() + } | { + old_name_2: (old_name_1, new_name) for new_name, [old_name_1, old_name_2] in fuse_names.items() + } + reversed_trace_names_1: list[str] = [] + reversed_trace_names_2: list[str] = [] + reversed_fuse_names_1: list[str] = [] + reversed_fuse_names_2: list[str] = [] + reversed_free_names: list[str] = [] + reversed_fuse_names_result: list[str] = [] + reversed_free_index: list[int] = [] + reversed_fuse_index_result: list[int] = [] + added_names: set[str] = set() + for index, name in zip(reversed(range(self.rank)), reversed(self.names)): + if name not in added_names: + trace_name: str | None = trace_map.get(name, None) + fuse_name: tuple[str, str] | None = fuse_map.get(name, None) + if trace_name is not None: + reversed_trace_names_2.append(name) + reversed_trace_names_1.append(trace_name) + added_names.add(trace_name) + elif fuse_name is not None: + # fuse_name = another old name, new name + reversed_fuse_names_2.append(name) + reversed_fuse_names_1.append(fuse_name[0]) + added_names.add(fuse_name[0]) + reversed_fuse_names_result.append(fuse_name[1]) + reversed_fuse_index_result.append(index) + else: + reversed_free_names.append(name) + reversed_free_index.append(index) + return ( + tuple(reversed(reversed_trace_names_1)), + tuple(reversed(reversed_trace_names_2)), + tuple(reversed(reversed_fuse_names_1)), + tuple(reversed(reversed_fuse_names_2)), + tuple(reversed(reversed_free_names)), + tuple(reversed(reversed_fuse_names_result)), + tuple(reversed(reversed_free_index)), + tuple(reversed(reversed_fuse_index_result)), + ) + + def trace( + self: Tensor, + trace_pairs: set[tuple[str, str]], + fuse_names: dict[str, tuple[str, str]] | None = None, + ) -> Tensor: + """ + Trace a tensor. + + Parameters + ---------- + trace_pairs : set[tuple[str, str]] + The pairs of edges to be traced + fuse_names : dict[str, tuple[str, str]] + The edges to be fused. + + Returns + ------- + Tensor + The traced tensor. + """ + # pylint: disable=too-many-locals + if fuse_names is None: + fuse_names = {} + # Fuse names should not have any symmetry + assert all( + all(_utility.zero_symmetry(sub_symmetry) + for sub_symmetry in self.edge_by_name(old_name_1).symmetry) + for new_name, [old_name_1, old_name_2] in fuse_names.items()) + # Fuse names should share the same edges + assert all( + self.edge_by_name(old_name_1) == self.edge_by_name(old_name_2) + for new_name, [old_name_1, old_name_2] in fuse_names.items()) + # Trace edges should be compatible + assert all( + self.edge_by_name(old_name_1).conjugate() == self.edge_by_name(old_name_2) + for old_name_1, old_name_2 in trace_pairs) + + # Split trace pairs and fuse names to two part before main part of trace. + [ + trace_names_1, + trace_names_2, + fuse_names_1, + fuse_names_2, + free_names, + fuse_names_result, + free_index, + fuse_index_result, + ] = self._trace_group_edge(trace_pairs, fuse_names) + + # Make alias + tensor = self + + # Tensor edges merged to 5 parts: fuse edge 1, fuse edge 2, trace edge 1, trace edge 2, free edge + # Trace contains 5 step: + # 1. reverse all arrow to False except trace edge 1 to be True, only apply parity to one of two trace edge + # 2. merge all edge to 5 part, only apply parity to one of two trace edge + # 3. trace trivial tensor + # 4. split edge merged in step 2, without apply parity + # 5. reverse arrow reversed in step 1, without apply parity + + # Step 1 + arrow_true_names = {name for name, edge in zip(tensor.names, tensor.edges) if edge.arrow} + unordered_trace_names_1 = set(trace_names_1) + tensor = tensor.reverse_edge(unordered_trace_names_1 ^ arrow_true_names, False, + unordered_trace_names_1 - arrow_true_names) + + # Step 2 + free_edges: tuple[tuple[str, Edge], ...] = tuple( + (name, tensor.edges[index]) for name, index in zip(free_names, free_index)) + fuse_edges_result: tuple[tuple[str, Edge], ...] = tuple( + (name, tensor.edges[index]) for name, index in zip(fuse_names_result, fuse_index_result)) + tensor = tensor.merge_edge( + { + "Trace_1": trace_names_1, + "Trace_2": trace_names_2, + "Fuse_1": fuse_names_1, + "Fuse_2": fuse_names_2, + "Free": free_names, + }, + False, + {"Trace_1"}, + merge_arrow={ + "Trace_1": True, + "Trace_2": False, + "Fuse_1": False, + "Fuse_2": False, + "Free": False, + }, + names=("Trace_1", "Trace_2", "Fuse_1", "Fuse_2", "Free"), + ) + # B[fuse, free] = A[trace, trace, fuse, fuse, free] + assert tensor.edges[2] == tensor.edges[3] + assert tensor.edges[0].conjugate() == tensor.edges[1] + + # Step 3 + # As tested, the order of edges in this einsum is not important + # ttffi->fi, fftti->fi, ffitt->fi, ttiff->if, ittff->if, ifftt->if + data = torch.einsum("ttffi->fi", tensor.data) + tensor = Tensor( + names=("Fuse", "Free"), + edges=(tensor.edges[2], tensor.edges[4]), + fermion=self.fermion, + dtypes=self.dtypes, + data=data, + ) + + # Step 4 + tensor = tensor.split_edge({ + "Fuse": fuse_edges_result, + "Free": free_edges, + }, False, set()) + + # Step 5 + tensor = tensor.reverse_edge( + # Free edge with arrow true + {name for name in free_names if name in arrow_true_names} | + # New edge from fused edge with arrow true + {new_name for old_name, new_name in zip(fuse_names_1, fuse_names_result) if old_name in arrow_true_names}, + False, + set(), + ) + + return tensor + + def conjugate(self: Tensor, trivial_metric: bool = False) -> Tensor: + """ + Get the conjugate of this tensor. + + Parameters + ---------- + trivial_metric : bool, default=False + Fermionic tensor in network may result in non positive definite metric, set this trivial_metric to True to + ensure the metric to be positive, but it breaks the associative law with tensor contract. + + Returns + ------- + Tensor + The conjugated tensor. + """ + data = torch.conj(self.data) + + # Usually, only a full transpose sign is applied. + # If trivial_metric is set True, parity in edges with arrow True is also applied. + + # Apply parity + if any(self.fermion): + # It is femionic tensor, parity need to be applied + + # Parity is parity generated from a full transpose + # {sum 0<=i Edge: + # pylint: disable=invalid-name + m, n = matrix.size() + assert edge.dimension == m + argmax = torch.argmax(matrix, dim=0) + assert argmax.size == (n,) + return Edge( + fermion=edge.fermion, + dtypes=edge.dtypes, + symmetry=tuple(_utility.neg_symmetry(sub_symmetry[argmax]) for sub_symmetry in edge.symmetry), + dimension=n, + arrow=arrow, + parity=edge.parity[argmax], + ) + + def svd( + self: Tensor, + free_names_u: set[str], + common_name_u: str, + common_name_v: str, + singular_name_u: str, + singular_name_v: str, + cut: int = -1, + fuse_names: set[str] | None = None, + ) -> tuple[Tensor, Tensor, Tensor]: + """ + SVD decomposition a tensor. + + Parameters + ---------- + free_names_u : set[str] + Free names in U tensor of the result of SVD. + common_name_u, common_name_v, singular_name_u, singular_name_v : str + The name of generated edges. + cut : int, default=-1 + The cut for the singular values. + fuse_names : set[str], optional + The names of fuse edges. + + Returns + ------- + tuple[Tensor, Tensor, Tensor] + U, S, V tensor, the result of SVD. + """ + # pylint: disable=too-many-arguments + # pylint: disable=too-many-locals + if fuse_names is None: + fuse_names = set() + # Fuse name should not have any symmetry + assert all( + all(_utility.zero_symmetry(sub_symmetry) + for sub_symmetry in self.edge_by_name(fuse_name).symmetry) + for fuse_name in fuse_names) + + free_names_v = {name for name in self.names if name not in free_names_u and name not in fuse_names} + + assert all(name in self.names for name in free_names_u) + assert all(name in self.names for name in fuse_names) + assert fuse_names & free_names_u == set() + assert common_name_u not in free_names_u | fuse_names + assert common_name_v not in free_names_v | fuse_names + + arrow_true_names = {name for name, edge in zip(self.names, self.edges) if edge.arrow} + + tensor = self.reverse_edge(arrow_true_names, False, set()) + + ordered_fuse_edges: tuple[tuple[str, Edge], ...] = tuple( + (name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in fuse_names) + ordered_free_edges_u: tuple[tuple[str, Edge], ...] = tuple( + (name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in free_names_u) + ordered_free_edges_v: tuple[tuple[str, Edge], ...] = tuple( + (name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in free_names_v) + ordered_fuse_names: tuple[str, ...] = tuple(name for name, _ in ordered_fuse_edges) + ordered_free_names_u: tuple[str, ...] = tuple(name for name, _ in ordered_free_edges_u) + ordered_free_names_v: tuple[str, ...] = tuple(name for name, _ in ordered_free_edges_v) + + put_v_right = next((name in free_names_v for name in reversed(tensor.names) if name not in fuse_names), True) + tensor = tensor.merge_edge( + { + "SVD_F": ordered_fuse_names, + "SVD_U": ordered_free_names_u, + "SVD_V": ordered_free_names_v + }, + False, + set(), + merge_arrow={ + "SVD_F": False, + "SVD_U": False, + "SVD_V": False + }, + names=("SVD_F", "SVD_U", "SVD_V") if put_v_right else ("SVD_F", "SVD_V", "SVD_U"), + ) + + data_1, data_s, data_2 = torch.linalg.svd(tensor.data, full_matrices=False) + if cut != -1: + data_1 = data_1[:, :, :cut] + data_s = data_s[:, :cut] + data_2 = data_2[:, :cut, :] + data_s = torch.diag_embed(data_s) + + fuse_edge = tensor.edges[0] + free_edge_1 = tensor.edges[1] + common_edge_1 = Tensor._guess_edge(torch.sum(torch.abs(data_1), dim=0), free_edge_1, True) + tensor_1 = Tensor( + names=("SVD_F", "SVD_U", common_name_u) if put_v_right else ("SVD_F", "SVD_V", common_name_v), + edges=(fuse_edge, free_edge_1, common_edge_1), + fermion=self.fermion, + dtypes=self.dtypes, + data=data_1, + ) + free_edge_2 = tensor.edges[2] + common_edge_2 = Tensor._guess_edge(torch.sum(torch.abs(data_2.transpose(0, 1)), dim=0), free_edge_2, False) + tensor_2 = Tensor( + names=("SVD_F", common_name_v, "SVD_V") if put_v_right else ("SVD_F", common_name_u, "SVD_U"), + edges=(fuse_edge, common_edge_2, free_edge_2), + fermion=self.fermion, + dtypes=self.dtypes, + data=data_2, + ) + assert common_edge_1.conjugate() == common_edge_2 + tensor_s = Tensor( + names=("SVD_F", singular_name_u, singular_name_v) if put_v_right else + ("SVD_F", singular_name_v, singular_name_u), + edges=(common_edge_2, common_edge_1), + fermion=self.fermion, + dtypes=self.dtypes, + data=data_s, + ) + + tensor_u = tensor_1 if put_v_right else tensor_2 + tensor_v = tensor_2 if put_v_right else tensor_1 + + tensor_u = tensor_u.split_edge({"SVD_U": ordered_free_edges_u, "SVD_F": ordered_fuse_edges}, False, set()) + tensor_v = tensor_v.split_edge({"SVD_V": ordered_free_edges_v, "SVD_F": ordered_fuse_edges}, False, set()) + tensor_s = tensor_s.split_edge({"SVD_F": ordered_fuse_edges}, False, set()) + + tensor_u = tensor_u.reverse_edge(arrow_true_names & (free_names_u | fuse_names), False, set()) + tensor_v = tensor_v.reverse_edge(arrow_true_names & (free_names_v | fuse_names), False, set()) + tensor_s = tensor_s.reverse_edge(arrow_true_names & fuse_names, False, set()) + + return tensor_u, tensor_s, tensor_v + + def qr(self: Tensor, + free_names_direction: str, + free_names: set[str], + common_name_q: str, + common_name_r: str, + fuse_names: set[str] | None = None) -> tuple[Tensor, Tensor]: + """ + QR decomposition on this tensor. + + Parameters + ---------- + free_names_direction : 'Q' | 'q' | 'R' | 'r' + Specify which direction the free_names will set + free_names : set[str] + The names of free edges after QR decomposition. + common_name_q, common_name_r : str + The names of edges created by QR decomposition. + fuse_names : set[str], optional + The names of fuse edges + + Returns + ------- + tuple[Tensor, Tensor] + Tensor Q and R, the result of QR decomposition. + """ + # pylint: disable=invalid-name + # pylint: disable=too-many-arguments + # pylint: disable=too-many-locals + if fuse_names is None: + fuse_names = set() + # Fuse name should not have any symmetry + assert all( + all(_utility.zero_symmetry(sub_symmetry) + for sub_symmetry in self.edge_by_name(fuse_name).symmetry) + for fuse_name in fuse_names) + + if free_names_direction in {'Q', 'q'}: + free_names_q = free_names + free_names_r = {name for name in self.names if name not in free_names and name not in fuse_names} + elif free_names_direction in {'R', 'r'}: + free_names_r = free_names + free_names_q = {name for name in self.names if name not in free_names and name not in fuse_names} + + assert all(name in self.names for name in free_names) + assert all(name in self.names for name in fuse_names) + assert fuse_names & free_names == set() + assert common_name_q not in free_names_q | fuse_names + assert common_name_r not in free_names_r | fuse_names + + arrow_true_names = {name for name, edge in zip(self.names, self.edges) if edge.arrow} + + tensor = self.reverse_edge(arrow_true_names, False, set()) + + ordered_fuse_edges: tuple[tuple[str, Edge], ...] = tuple( + (name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in fuse_names) + ordered_free_edges_q: tuple[tuple[str, Edge], ...] = tuple( + (name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in free_names_q) + ordered_free_edges_r: tuple[tuple[str, Edge], ...] = tuple( + (name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in free_names_r) + ordered_fuse_names: tuple[str, ...] = tuple(name for name, _ in ordered_fuse_edges) + ordered_free_names_q: tuple[str, ...] = tuple(name for name, _ in ordered_free_edges_q) + ordered_free_names_r: tuple[str, ...] = tuple(name for name, _ in ordered_free_edges_r) + + # pytorch does not provide LQ, so always put r right here + tensor = tensor.merge_edge( + { + "QR_F": ordered_fuse_names, + "QR_Q": ordered_free_names_q, + "QR_R": ordered_free_names_r + }, + False, + set(), + merge_arrow={ + "QR_F": False, + "QR_Q": False, + "QR_R": False + }, + names=("QR_F", "QR_Q", "QR_R"), + ) + + data_q, data_r = torch.linalg.qr(tensor.data, mode="reduced") + + fuse_edge = tensor.edges[0] + free_edge_q = tensor.edges[1] + common_edge_q = Tensor._guess_edge(torch.sum(torch.abs(data_q), dim=0), free_edge_q, True) + tensor_q = Tensor( + names=("QR_F", "QR_Q", common_name_q), + edges=(fuse_edge, free_edge_q, common_edge_q), + fermion=self.fermion, + dtypes=self.dtypes, + data=data_q, + ) + free_edge_r = tensor.edges[2] + common_edge_r = Tensor._guess_edge(torch.sum(torch.abs(data_r.transpose(0, 1)), dim=0), free_edge_r, False) + tensor_r = Tensor( + names=("QR_F", common_name_r, "QR_R"), + edges=(fuse_edge, common_edge_r, free_edge_r), + fermion=self.fermion, + dtypes=self.dtypes, + data=data_r, + ) + assert common_edge_q.conjugate() == common_edge_r + + tensor_q = tensor_q.split_edge({"QR_Q": ordered_free_edges_q, "QR_F": ordered_fuse_edges}, False, set()) + tensor_r = tensor_r.split_edge({"QR_R": ordered_free_edges_r, "QR_F": ordered_fuse_edges}, False, set()) + + tensor_q = tensor_q.reverse_edge(arrow_true_names & (free_names_q | fuse_names), False, set()) + tensor_r = tensor_r.reverse_edge(arrow_true_names & (free_names_r | fuse_names), False, set()) + + return tensor_q, tensor_r + + def identity(self: Tensor, pairs: set[tuple[str, str]]) -> Tensor: + """ + Get the identity tensor with same shape to this tensor. + + Parameters + ---------- + pairs : set[tuple[str, str]] + The pair of edge names to specify the relation among edges to set identity tensor. + + Returns + ------- + Tensor + The result identity tensor. + """ + # The order of edges before setting identity should be (False True) + # Merge tensor directly to two edge, set identity and split it directly. + # When splitting, only apply parity to one part of edges + + # pylint: disable=unnecessary-comprehension + pairs_map = {name_1: name_2 for name_1, name_2 in pairs} | {name_2: name_1 for name_1, name_2 in pairs} + added_names: set[str] = set() + reversed_names_1: list[str] = [] + reversed_names_2: list[str] = [] + for name in reversed(self.names): + if name not in added_names: + another_name = pairs_map[name] + reversed_names_2.append(name) + reversed_names_1.append(another_name) + added_names.add(another_name) + names_1 = tuple(reversed(reversed_names_1)) + names_2 = tuple(reversed(reversed_names_2)) + # unordered_names_1 = set(names_1) + unordered_names_2 = set(names_2) + + arrow_true_names = {name for name, edge in zip(self.names, self.edges) if edge.arrow} + + # Two edges, arrow of two edges are (False, True) + tensor = self.reverse_edge(unordered_names_2 ^ arrow_true_names, False, unordered_names_2 - arrow_true_names) + + edges_1 = tuple((name, tensor.edge_by_name(name)) for name in names_1) + edges_2 = tuple((name, tensor.edge_by_name(name)) for name in names_2) + + tensor = tensor.merge_edge( + { + "Identity_1": names_1, + "Identity_2": names_2 + }, + False, + {"Identity_2"}, + merge_arrow={ + "Identity_1": False, + "Identity_2": True + }, + names=("Identity_1", "Identity_2"), + ) + + tensor = Tensor( + names=tensor.names, + edges=tensor.edges, + fermion=tensor.fermion, + dtypes=tensor.dtypes, + data=torch.eye(*tensor.data.size()), + mask=tensor.mask, + ) + + tensor = tensor.split_edge({"Identity_1": edges_1, "Identity_2": edges_2}, False, {"Identity_2"}) + + tensor = tensor.reverse_edge(unordered_names_2 ^ arrow_true_names, False, unordered_names_2 - arrow_true_names) + + return tensor + + def exponential(self: Tensor, pairs: set[tuple[str, str]]) -> Tensor: + """ + Get the exponential tensor of this tensor. + + Parameters + ---------- + pairs : set[tuple[str, str]] + The pair of edge names to specify the relation among edges to calculate exponential tensor. + + Returns + ------- + Tensor + The result exponential tensor. + """ + # The order of edges before setting exponential should be (False True) + # Merge tensor directly to two edge, set exponential and split it directly. + # When splitting, only apply parity to one part of edges + + unordered_names_1 = {name_1 for name_1, name_2 in pairs} + unordered_names_2 = {name_2 for name_1, name_2 in pairs} + if self.names and self.names[-1] in unordered_names_1: + unordered_names_1, unordered_names_2 = unordered_names_2, unordered_names_1 + # pylint: disable=unnecessary-comprehension + pairs_map = {name_1: name_2 for name_1, name_2 in pairs} | {name_2: name_1 for name_1, name_2 in pairs} + reversed_names_1: list[str] = [] + reversed_names_2: list[str] = [] + for name in reversed(self.names): + if name in unordered_names_2: + another_name = pairs_map[name] + reversed_names_2.append(name) + reversed_names_1.append(another_name) + names_1 = tuple(reversed(reversed_names_1)) + names_2 = tuple(reversed(reversed_names_2)) + + arrow_true_names = {name for name, edge in zip(self.names, self.edges) if edge.arrow} + + # Two edges, arrow of two edges are (False, True) + tensor = self.reverse_edge(unordered_names_2 ^ arrow_true_names, False, unordered_names_2 - arrow_true_names) + + edges_1 = tuple((name, tensor.edge_by_name(name)) for name in names_1) + edges_2 = tuple((name, tensor.edge_by_name(name)) for name in names_2) + + tensor = tensor.merge_edge( + { + "Exponential_1": names_1, + "Exponential_2": names_2 + }, + False, + {"Exponential_2"}, + merge_arrow={ + "Exponential_1": False, + "Exponential_2": True + }, + names=("Exponential_1", "Exponential_2"), + ) + + tensor = Tensor( + names=tensor.names, + edges=tensor.edges, + fermion=tensor.fermion, + dtypes=tensor.dtypes, + data=torch.linalg.matrix_exp(tensor.data), + mask=tensor.mask, + ) + + tensor = tensor.split_edge({"Exponential_1": edges_1, "Exponential_2": edges_2}, False, {"Exponential_2"}) + + tensor = tensor.reverse_edge(unordered_names_2 ^ arrow_true_names, False, unordered_names_2 - arrow_true_names) + + return tensor diff --git a/tests/test_compat.py b/tests/test_compat.py new file mode 100644 index 000000000..3bac86a7f --- /dev/null +++ b/tests/test_compat.py @@ -0,0 +1,113 @@ +import torch +import tat +from tat import compat as TAT + + +def test_edge_from_dimension(): + assert TAT.No.Edge(4) == tat.Edge(dimension=4) + assert TAT.Fermi.Edge(4) == tat.Edge(fermion=(True,), + symmetry=(torch.tensor([0, 0, 0, 0], dtype=torch.int),), + arrow=False) + assert TAT.Z2.Edge(4) == tat.Edge(symmetry=(torch.tensor([False, False, False, False]),)) + + +def test_edge_from_segments(): + assert TAT.Z2.Edge([ + (False, 2), + (True, 3), + ]) == tat.Edge(symmetry=(torch.tensor([False, False, True, True, True]),),) + assert TAT.Fermi.Edge([ + (-1, 1), + (0, 2), + (+1, 3), + ], True) == tat.Edge( + symmetry=(torch.tensor([-1, 0, 0, +1, +1, +1], dtype=torch.int),), + arrow=True, + fermion=(True,), + ) + assert TAT.FermiFermi.Edge([ + ((-1, -2), 1), + ((0, +1), 2), + ((+1, 0), 3), + ], True) == tat.Edge( + symmetry=( + torch.tensor([-1, 0, 0, +1, +1, +1], dtype=torch.int), + torch.tensor([-2, +1, +1, 0, 0, 0], dtype=torch.int), + ), + arrow=True, + fermion=(True, True), + ) + + +def test_edge_from_segments_without_dimension(): + assert TAT.Z2.Edge([False, True]) == tat.Edge(symmetry=(torch.tensor([False, True]),)) + assert TAT.Fermi.Edge([-1, 0, +1], True) == tat.Edge( + symmetry=(torch.tensor([-1, 0, +1], dtype=torch.int),), + arrow=True, + fermion=(True,), + ) + assert TAT.FermiFermi.Edge([ + (-1, -2), + (0, +1), + (+1, 0), + ], True) == tat.Edge( + symmetry=(torch.tensor([-1, 0, +1], dtype=torch.int), torch.tensor([-2, +1, 0], dtype=torch.int)), + arrow=True, + fermion=(True, True), + ) + + +def test_edge_from_tuple(): + assert TAT.FermiFermi.Edge(([ + ((-1, -2), 1), + ((0, +1), 2), + ((+1, 0), 3), + ], True)) == tat.Edge( + symmetry=( + torch.tensor([-1, 0, 0, +1, +1, +1], dtype=torch.int), + torch.tensor([-2, +1, +1, 0, 0, 0], dtype=torch.int), + ), + arrow=True, + fermion=(True, True), + ) + assert TAT.FermiFermi.Edge(([ + (-1, -2), + (0, +1), + (+1, 0), + ], True)) == tat.Edge( + symmetry=(torch.tensor([-1, 0, +1], dtype=torch.int), torch.tensor([-2, +1, 0], dtype=torch.int)), + arrow=True, + fermion=(True, True), + ) + + +def test_tensor(): + a = TAT.FermiZ2.D.Tensor(["i", "j"], [ + [(-1, False), (-1, True), (0, True), (0, False)], + [(+1, True), (+1, False), (0, False), (0, True)], + ]) + b = tat.Tensor( + ( + "i", + "j", + ), + ( + tat.Edge( + fermion=(True, False), + symmetry=( + torch.tensor([-1, -1, 0, 0], dtype=torch.int), + torch.tensor([False, True, True, False]), + ), + arrow=False, + ), + tat.Edge( + fermion=(True, False), + symmetry=( + torch.tensor([+1, +1, 0, 0], dtype=torch.int), + torch.tensor([True, False, False, True]), + ), + arrow=False, + ), + ), + ) + assert a.same_shape_with(b, allow_transpose=False) diff --git a/tests/test_create_tensor.py b/tests/test_create_tensor.py new file mode 100644 index 000000000..079609ffa --- /dev/null +++ b/tests/test_create_tensor.py @@ -0,0 +1,89 @@ +import torch +import tat + + +def test_create_tensor(): + a = tat.Tensor( + ( + "i", + "j", + ), + ( + tat.Edge(symmetry=(torch.tensor([False, False, True]),), fermion=(True,), arrow=True), + tat.Edge(symmetry=(torch.tensor([False, False, False, True, True]),), fermion=(True,), arrow=False), + ), + ) + assert a.rank == 2 + assert a.names == ("i", "j") + assert a.edges[0] == tat.Edge(symmetry=(torch.tensor([False, False, True]),), fermion=(True,), arrow=True) + assert a.edges[1] == tat.Edge(symmetry=(torch.tensor([False, False, False, True, True]),), + fermion=(True,), + arrow=False) + assert a.edges[0] == a.edge_by_name("i") + assert a.edges[1] == a.edge_by_name("j") + + +def test_tensor_get_set_item(): + a = tat.Tensor( + ( + "i", + "j", + ), + ( + tat.Edge(symmetry=(torch.tensor([False, False, True]),), fermion=(True,), arrow=True), + tat.Edge(symmetry=(torch.tensor([False, False, False, True, True]),), fermion=(True,), arrow=False), + ), + ) + a[{"i": 0, "j": 0}] = 1 + assert a[0, 0] == 1 + a["i":2, "j":3] = 2 + assert a[{"i": 2, "j": 3}] == 2 + a[2, 0] = 3 + assert a["i":2, "j":0] == 0 + + b = tat.Tensor( + ( + "i", + "j", + ), + ( + tat.Edge(symmetry=(torch.tensor([0, 0, -1]),), fermion=(False,)), + tat.Edge(symmetry=(torch.tensor([0, 0, 0, +1, +1]),), fermion=(False,)), + ), + ) + b[{"i": 0, "j": 0}] = 1 + assert b[0, 0] == 1 + b["i":2, "j":3] = 2 + assert b[{"i": 2, "j": 3}] == 2 + b[2, 0] = 3 + assert b["i":2, "j":0] == 0 + + +def test_create_randn_tensor(): + a = tat.Tensor( + ("i", "j"), + ( + tat.Edge(symmetry=(torch.tensor([False, True]),)), + tat.Edge(symmetry=(torch.tensor([False, True]),)), + ), + dtype=torch.float16, + ).randn_() + assert a.dtype == torch.float16 + assert a[0, 0] != 0 + assert a[1, 1] != 0 + assert a[0, 1] == 0 + assert a[1, 0] == 0 + + b = tat.Tensor( + ("i", "j"), + ( + tat.Edge(symmetry=(torch.tensor([False, False]), torch.tensor([0, -1]))), + tat.Edge(symmetry=(torch.tensor([False, False]), torch.tensor([0, +1]))), + ), + dtype=torch.float16, + ).randn_() + assert b.dtype == torch.float16 + assert b[0, 0] != 0 + assert b[1, 1] != 0 + assert b[0, 1] == 0 + assert b[1, 0] == 0 diff --git a/tests/test_edge.py b/tests/test_edge.py new file mode 100644 index 000000000..096973e71 --- /dev/null +++ b/tests/test_edge.py @@ -0,0 +1,33 @@ +import torch +from tat import Edge + + +def test_create_edge_and_basic(): + a = Edge(dimension=5) + assert a.arrow == False + assert a.dimension == 5 + b = Edge(symmetry=(torch.tensor([False, False, True, True]),)) + assert b.arrow == False + assert b.dimension == 4 + c = Edge(fermion=(False, True), symmetry=(torch.tensor([False, True]), torch.tensor([False, True])), arrow=True) + assert c.arrow == True + assert c.dimension == 2 + + +def test_edge_conjugate_and_equal(): + a = Edge(fermion=(False, True), symmetry=(torch.tensor([False, True]), torch.tensor([0, 1])), arrow=True) + b = Edge(fermion=(False, True), symmetry=(torch.tensor([False, True]), torch.tensor([0, -1])), arrow=False) + assert a.conjugate() == b + assert a != 2 + + +def test_repr(): + a = Edge(fermion=(False, True), symmetry=(torch.tensor([False, True]), torch.tensor([0, 1])), arrow=True) + repr_a = repr(a) + assert repr_a == "Edge(dimension=2, arrow=True, fermion=(False, True), symmetry=(tensor([False, True]), tensor([0, 1])))" + b = Edge(symmetry=(torch.tensor([False, True]), torch.tensor([0, 1]))) + repr_b = repr(b) + assert repr_b == "Edge(dimension=2, symmetry=(tensor([False, True]), tensor([0, 1])))" + c = Edge(dimension=4) + repr_c = repr(c) + assert repr_c == "Edge(dimension=4)"